"""
Major library for graph dictionary learning algorithms

Created on: July 25, 2022

"""
# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=no-else-return

from time import time
import sys
import os
import pickle
from collections import defaultdict
from multiprocessing import Process, Manager
from ctypes import c_char_p
import heapq
from operator import itemgetter
import shutil
import itertools

import numpy as np
from sklearn import linear_model
from sklearn.utils.extmath import randomized_svd
import scipy
import scipy.sparse as sp
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

# from utils import load_train_conf, fast_numpy_slicing
from model import MLP
from data_utils import scipy_coo_to_torch_sparse

# Global switch for extra diagnostics. It is only consulted together with the objects'
# `verbose` flag (progress bars, extra timing lines in the training loops).
DEBUG = True
# Global switch for TensorBoard logging in SDMP_dict (SummaryWriter writing to "runs/tests").
TB = False

if TB:
    # NOTE(review): presumably a workaround so that torch's SummaryWriter.add_embedding works
    # when TensorFlow is installed (it replaces tf.io.gfile with tensorboard's stub) — confirm.
    import tensorflow as tf
    import tensorboard as tb
    tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

def torch_batch_matrix_mul_matrix_list(a, list_b, batch_size=1024, device="cpu"):
    """
    Compute ``a @ b_1 @ b_2 @ ... @ b_k`` with torch, batching over the rows of ``a``.

    input:
    a: the first matrix (2-D numpy array); processed batch_size rows at a time
    list_b: either a list of matrices or a single matrix (numpy arrays), multiplied on the
            right of ``a`` in order
    batch_size: rows of ``a`` per batch; -1 (or any non-positive value) means one batch
    device: torch device used for the multiplications
    output:
    float32 numpy array holding the full product
    """
    if not isinstance(list_b, list):
        list_b = [list_b]
    n_rows = a.shape[0]
    if n_rows == 0:
        # nothing to multiply; np.concatenate would fail on an empty list
        n_cols = list_b[-1].shape[1] if list_b else a.shape[1]
        return np.zeros((0, n_cols), dtype=np.float32)
    if batch_size <= 0:
        batch_size = n_rows
    max_batch = int(np.ceil(n_rows / batch_size))
    with torch.no_grad():
        # the right-hand factors are shared by every batch: move them to the device once
        list_torch_b = [torch.from_numpy(b).float().to(device) for b in list_b]
        list_res = []
        for i in range(max_batch):
            rows = np.ascontiguousarray(a[i*batch_size:(i+1)*batch_size, :])
            cur_product = torch.from_numpy(rows).float().to(device)
            for torch_b in list_torch_b:
                # out-of-place product: the shape changes from step to step
                cur_product = cur_product @ torch_b
            list_res.append(cur_product.cpu().numpy())
    return np.concatenate(list_res, axis=0)

def torch_batch_matrix_list_mul_matrix(list_a, b, batch_size=1024, device="cpu"):
    """
    Compute ``a_1 @ a_2 @ ... @ a_k @ b`` with torch, batching over the columns of ``b``.

    input:
    list_a: either a list of matrices or a single matrix (numpy arrays), multiplied on the
            left of ``b`` in order
    b: the last matrix (2-D numpy array); processed batch_size columns at a time
    batch_size: columns of ``b`` per batch; -1 (or any non-positive value) means one batch
    device: torch device used for the multiplications
    output:
    float32 numpy array holding the full product
    """
    if not isinstance(list_a, list):
        list_a = [list_a]
    n_cols = b.shape[1]
    if n_cols == 0:
        # nothing to multiply; np.concatenate would fail on an empty list
        n_rows = list_a[0].shape[0] if list_a else b.shape[0]
        return np.zeros((n_rows, 0), dtype=np.float32)
    if batch_size <= 0:
        batch_size = n_cols
    max_batch = int(np.ceil(n_cols / batch_size))
    with torch.no_grad():
        # the left-hand factors are shared by every batch: move them to the device once
        list_torch_a = [torch.from_numpy(a).float().to(device) for a in list_a]
        list_res = []
        for i in range(max_batch):
            cols = np.ascontiguousarray(b[:, i*batch_size:(i+1)*batch_size])
            cur_product = torch.from_numpy(cols).float().to(device)
            # multiply right-to-left so that every intermediate keeps only this batch's columns
            for torch_a in reversed(list_torch_a):
                cur_product = torch_a @ cur_product
            list_res.append(cur_product.cpu().numpy())
    return np.concatenate(list_res, axis=1)

def np_wthresh(A: np.ndarray, lam: float) -> np.ndarray:
    """
    Element-wise soft thresholding of a numpy array.

    Entries with |A| <= lam become 0; the others are shrunk towards zero by lam, keeping their
    sign. The result is a new array with the same dtype as A (A is not modified).
    """
    magnitude = np.abs(A)
    above = magnitude > lam
    result = A.copy()
    result[magnitude <= lam] = 0
    result[above] = (magnitude[above] - lam) * np.sign(A)[above]
    return result

def torch_wthresh(A: torch.Tensor, lam: torch.Tensor) -> torch.Tensor:
    """
    Element-wise soft thresholding of a torch tensor: sign(A) * threshold(|A| - lam, 0, 0).

    ``torch.nn.functional.threshold(x, 0, 0)`` keeps x where x > 0 and writes 0 elsewhere, so
    entries with |A| <= lam vanish and the rest are shrunk towards zero by lam.
    """
    shrunk_magnitude = torch.nn.functional.threshold(torch.abs(A) - lam, 0, 0)
    return torch.sign(A) * shrunk_magnitude

class SDMP: # pylint:disable=too-many-instance-attributes
    """
    Main class to learn the sparse decomposition of message passing.

    The target ``Omega`` is approximated by ``ThetaT @ h(X)``: ``ThetaT`` is a sparse
    (data_size x data_size) matrix whose rows are fitted one by one with a positive LassoLars,
    and ``h`` is an MLP trained with Adam on the MSE between ``ThetaT @ h(X)`` and ``Omega``.
    """
    def __init__(self, # pylint: disable=too-many-arguments
                 X,
                 Omega,
                 A,
                 epoch=4,
                 batch_size=64,
                 eval_step=20,
                 eval_batch_size=-1,
                 n_theta_nonzero=20,
                 neighbour_candidate_mode="full",
                 h_base_k=2,
                 h_init_epoch=1,
                 h_hidden=None,
                 h_loop_cnt=1,
                 h_lr=1.0e-2,
                 h_l2=1.0e-4,
                 h_dropout=0.0,
                 device="cpu",
                 verbose=True):
        """
        input:
        X: numpy feature matrix, one row per node; rows are the inputs of h
        Omega: numpy target matrix, one row per node; its column count is h's output size
        A: scipy sparse matrix (data_size x data_size); A^T initialises ThetaT when h is
           pre-trained, and in "sparse" mode row i of A lists the candidates of node i
        epoch: passes over all nodes in fit
        batch_size: nodes per mini-batch
        eval_step: evaluate and log every eval_step global iterations
        eval_batch_size: batch size used for evaluation; -1 means all nodes at once
        n_theta_nonzero: max_iter given to LassoLars (limits the LARS steps, and so the
                         nonzeros, per row of ThetaT)
        neighbour_candidate_mode: "full" (every node is a candidate) or "sparse" (only the
                                  stored entries of the node's row in A)
        h_init_epoch: epochs of h pre-training with ThetaT = A^T
        h_hidden: hidden layer sizes of the MLP h; None means [64]
        h_lr, h_l2: Adam learning rate and weight decay for h
        h_base_k, h_loop_cnt, h_dropout: stored only; not used by this class
                                         (h_dropout is not passed to the MLP)
        device: torch device
        verbose: print progress information
        """
        self.X = X
        self.Omega = Omega
        self.A = A
        self.epoch = epoch
        self.batch_size = batch_size
        self.eval_step = eval_step
        self.eval_batch_size = eval_batch_size
        self.n_theta_nonzero = n_theta_nonzero
        self.h_base_k = h_base_k
        self.h_init_epoch = h_init_epoch
        # fresh list per instance instead of a shared mutable default argument
        self.h_hidden = [64] if h_hidden is None else h_hidden
        self.h_loop_cnt = h_loop_cnt
        self.h_lr = h_lr
        self.h_l2 = h_l2
        self.h_dropout = h_dropout
        self.neighbour_candidate_mode = neighbour_candidate_mode
        self.device = device
        self.verbose = verbose

        # initialize parameters
        self.log = defaultdict(list)
        self.data_size, self.feature_dim = self.X.shape
        _, self.GNN_encoder_dim = self.Omega.shape
        self.global_iter_cnt = None
        self._iter_cnt, self._sample_idx = None, None
        self._max_iter = int(np.ceil(self.data_size / self.batch_size))
        if self.eval_batch_size == -1:
            self.eval_batch_size = self.data_size

        self.torch_features = None
        self.ThetaT, self.H = None, None # ThetaT consistent with paper notation
        ## initialize h function
        self.h = MLP(self.feature_dim, self.h_hidden, self.GNN_encoder_dim).to(self.device)
        ## initialize optimizer
        self.h_opt = torch.optim.Adam(self.h.parameters(),
                                      lr=self.h_lr,
                                      weight_decay=self.h_l2)
        self.h_loss_func = nn.MSELoss()

        # precomputing
        if self.neighbour_candidate_mode == 'sparse':
            self.neighbour_candidate = self.gen_neighbour_candidate()

    def gen_neighbour_candidate(self):
        """
        Generate the neighbour candidates of the "sparse" mode: for each node i, the list of
        column indices of the stored entries of row i of A (``lil_matrix.rows``).
        """
        return sp.lil_matrix(self.A).rows

    def _sampler(self):
        """Yield the node indices of each mini-batch of one shuffled epoch."""
        # initialize the parameters
        self._sample_idx = np.arange(self.data_size)
        np.random.shuffle(self._sample_idx)
        # main loop
        for self._iter_cnt in range(self._max_iter):
            yield self._sample_idx[self._iter_cnt*self.batch_size:
                                   min((self._iter_cnt+1)*self.batch_size, self.data_size)]

    def _init_Theta(self):
        """Return an empty (all-zero) data_size x data_size lil matrix for ThetaT."""
        ThetaT = sp.lil_matrix((self.data_size, self.data_size), dtype='float')
        return ThetaT

    def _forward_h(self, idx, batch_size, local_verbose):
        """Apply h to X[idx] in mini-batches without changing h's train/eval mode."""
        tic = time()
        if idx is None:
            idx = np.arange(self.data_size)
        n_idx = len(idx)
        if n_idx == 0:
            # torch.cat cannot concatenate an empty list: return an empty feature block
            return torch.zeros((0, self.GNN_encoder_dim), device=self.device)
        max_batch = int(np.ceil(n_idx / batch_size))
        show_progress = self.verbose and DEBUG and local_verbose
        if show_progress:
            print("Started inferencing H...")
        list_H = []
        for bb in range(max_batch):
            cur_idx = idx[bb*batch_size:min((bb+1)*batch_size, n_idx)]
            cur_X = torch.from_numpy(self.X[cur_idx, :]).float().to(self.device)
            list_H.append(self.h(cur_X)) # pylint: disable=not-callable
            if show_progress:
                cur_time = time() - tic
                ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                sys.stdout.write(f"{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                                 f"ETA {ETA:.1f}\r")
                sys.stdout.flush()
        if show_progress:
            print()
        return torch.cat(list_H, dim=0)

    def get_torch_H(self, idx=None, eval_batch_size=64, local_verbose=True):
        """
        Infer H = h(X[idx]) with h in eval mode.

        input:
        idx: node indices to infer (None means all nodes); output rows follow this order
        eval_batch_size: rows per forward pass
        local_verbose: show progress (only if self.verbose and DEBUG are also set)
        output:
        torch tensor on self.device. Gradient tracking follows the caller's autograd context;
        wrap the call in torch.no_grad() for pure inference.
        """
        self.h.eval() # set to eval mode!
        return self._forward_h(idx, eval_batch_size, local_verbose)

    def eval_metrics_torch(self, local_verbose=True):
        """
        Evaluate the reconstruction of Omega by ThetaT @ h(X).

        output:
        regret: sum over nodes of ||(ThetaT @ H)_i - Omega_i||^2
        rel_regret: mean over nodes of that squared error divided by ||Omega_i||^2
        """
        if self.verbose and DEBUG and local_verbose:
            print("Evaluating metrics...")
        list_diff, list_norm = [], []
        max_batch = int(np.ceil(self.data_size/self.eval_batch_size))

        with torch.no_grad():
            # pure inference: no autograd graph over all nodes is needed
            torch_H = self.get_torch_H(eval_batch_size=self.eval_batch_size,
                                       local_verbose=local_verbose)
            if self.verbose and DEBUG and local_verbose:
                print("Started evaluating the metrics...")
            tic = time()
            for bb in range(max_batch):
                start = bb*self.eval_batch_size
                stop = min((bb+1)*self.eval_batch_size, self.data_size)
                cur_ThetaT = scipy_coo_to_torch_sparse(self.ThetaT[start:stop, :].tocoo())\
                    .float().to(self.device)
                ThetaT_dot_H = torch.sparse.mm(cur_ThetaT, torch_H)
                cur_Omega = torch.from_numpy(self.Omega[start:stop, :]).float().to(self.device)
                diff_row_square_sum = torch.norm(ThetaT_dot_H - cur_Omega, dim=1) ** 2
                self_norm_square_sum = torch.norm(cur_Omega, dim=1) ** 2

                list_diff.append(diff_row_square_sum.cpu().detach().numpy())
                list_norm.append(self_norm_square_sum.cpu().detach().numpy())
                if self.verbose and DEBUG and local_verbose:
                    cur_time = time() - tic
                    ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                    sys.stdout.write(f"{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                                     f"ETA {ETA:.1f}\r")
                    sys.stdout.flush()
                # end of main loop
            if self.verbose and DEBUG and local_verbose:
                print()
            diff = np.concatenate(list_diff)
            norm = np.concatenate(list_norm)
            regret = np.sum(diff)
            rel_regret = np.sum(diff/norm) / self.data_size
            return regret, rel_regret

    def eval_and_log(self):
        """Evaluate the current model and append the metrics and ThetaT sparsity to the log."""
        this_regret, this_rel_regret = self.eval_metrics_torch(local_verbose=False)

        self.log["regret"].append(this_regret)
        self.log["rel_regret"].append(this_rel_regret)
        # nonzeros per row of ThetaT, i.e. per column of Theta
        self.log["Theta_col_nonzero_cnt"].append(np.sum(self.ThetaT!=0, axis=1))
        self.log["Theta_col_nonzero_stat"].append(
            (np.mean(self.log["Theta_col_nonzero_cnt"][-1]),
             np.std(self.log["Theta_col_nonzero_cnt"][-1]))
        )

    def _display_stat(self, i=-1):
        """Print the regret and sparsity statistics stored at position i of the log."""
        print(f"    Regret: {self.log['regret'][i]:.4f}"
              f" | Rel Regret: {self.log['rel_regret'][i]:.4f}"
              f" || Theta col nonzeros: {self.log['Theta_col_nonzero_stat'][i][0]:.3f} " + u"\u00B1"
              f" {self.log['Theta_col_nonzero_stat'][i][1]:.3f}")

    def _display_more_stat(self, i=-1):
        """Print the timing statistics stored at position i of the log."""
        print(f"    Times: Theta all {self.log['time_Theta'][i]:.1f}"
              f" | Theta pre {self.log['time_Theta_preprocess'][i]:.1f}"
              f" | Theta lar {self.log['time_Theta_lar'][i]:.1f}"
              f" | h all {self.log['time_h'][i]:.1f}")

    def fit(self): # pylint: disable=too-many-statements
        """
        Full training: pre-train h on A^T, alternate ThetaT / h updates for self.epoch
        epochs, then recompute every row of ThetaT with the final h. Returns self.
        """
        if self.verbose:
            print("Initializing...")
        self.ThetaT = self._init_Theta()
        self.eval_and_log()
        if self.verbose:
            print("Training started...")
            self._display_stat()
            print("-"*27, end="\n\n")
        # preprocessing
        tic_pre = time()
        if self.verbose:
            print("Preprocessing...")
        self.init_process_h()
        self.log['time_preprocess'].append(time()-tic_pre)
        if self.verbose:
            print(f"Preprocessing finished in {self.log['time_preprocess'][-1]:.1f}")
        self.eval_and_log()
        if self.verbose:
            self._display_stat()
            print("Main loop begins...")
            print("-"*27, end="\n\n")
        # main loop
        self.global_iter_cnt = 0
        total_iter = self.epoch * self._max_iter
        tic_start = time()
        for e in range(self.epoch):
            for it, cur_idx in enumerate(self._sampler()):
                self.global_iter_cnt += 1
                # phase Theta
                tic_Theta_start = time()
                self.update_Theta(cur_idx)
                self.log['time_Theta'].append(time()-tic_Theta_start)
                # phase h
                tic_h_start = time()
                self.update_h(cur_idx)
                self.log['time_h'].append(time()-tic_h_start)
                # evaluation and display
                if self.global_iter_cnt % self.eval_step ==0:
                    self.eval_and_log()
                    if self.verbose:
                        elapsed_time = time() - tic_start
                        ETA = elapsed_time /\
                            self.global_iter_cnt * (total_iter - self.global_iter_cnt)
                        print("-"*5)
                        print(f"Epoch: {e} | Iter: {it} | Global iter: {self.global_iter_cnt} "
                              f"Elapsed time: {elapsed_time:.1f} | ETA: {ETA:.1f}")
                        self._display_stat()
                        if DEBUG: # print more info for debugging mode
                            self._display_more_stat()

        # Post update
        if self.verbose:
            print("-"*27)
            print()
            print("Started post processing...")
        self.post_update()
        self.eval_and_log()
        if self.verbose:
            print(f"Training finished in {(time()-tic_start):.1f} s.")
            self._display_stat()
        return self
        # end fit

    # @profile
    def update_Theta(self, cur_idx, log=True):
        """
        Refit the rows of ThetaT that belong to the nodes in cur_idx.

        For every node i a positive LassoLars (no intercept, at most n_theta_nonzero steps)
        expresses Omega[i] as a combination of the rows of H restricted to i's neighbour
        candidates; the nonzero coefficients overwrite row i of ThetaT.
        input:
        cur_idx: node indices of the current mini-batch
        log: if True, append the preprocessing / LassoLars times to self.log
        """
        time_preprocessing, time_lar = 0.0, 0.0
        # prepare the data
        tic = time()
        receptive_idx, local_map = self.get_batch_Theta_candidate(cur_idx)
        # inference only: no autograd graph is needed for H here
        with torch.no_grad():
            torch_cur_H = self.get_torch_H(idx=receptive_idx, local_verbose=False)
        time_preprocessing += time() - tic
        # execute LAR and update the results
        for i, i_local_map in zip(cur_idx, local_map):
            tic_i_pre_start = time()
            i_target = self.Omega[np.array([i]), :]
            torch_i_H = torch_cur_H[i_local_map, :]
            # Gram matrix X^T X of the LassoLars design X = torch_i_H^T
            my_gram = torch.mm(torch_i_H, torch_i_H.t()).cpu().detach().numpy()
            X = torch_i_H.cpu().detach().numpy().transpose()
            y = i_target.transpose()
            # X^T y; cast the target to float32 so it matches H's dtype
            Xy = torch.mm(torch_i_H, torch.from_numpy(y).float().to(self.device))
            Xy = Xy.cpu().detach().numpy()
            time_preprocessing += time() - tic_i_pre_start
            # execute Lars
            tic_lars_start = time()
            # NOTE(review): `normalize` is deprecated in recent scikit-learn releases — confirm
            # against the installed version.
            reg = linear_model.LassoLars(alpha=0, precompute=my_gram, max_iter=self.n_theta_nonzero,
                                         normalize=False, fit_intercept=False, positive=True)
            reg.fit(X=X, y=y, Xy=Xy)
            # collect the results and map local candidate positions back to node ids
            local_res_coo = sp.coo_matrix(np.array(reg.coef_))
            self.ThetaT.rows[i] = receptive_idx[i_local_map[local_res_coo.col]].tolist()
            self.ThetaT.data[i] = local_res_coo.data.tolist()
            time_lar += time() - tic_lars_start
        if log:
            self.log['time_Theta_preprocess'].append(time_preprocessing)
            self.log['time_Theta_lar'].append(time_lar)

    def update_h(self, cur_idx):
        """Take one training step of h on the nodes in cur_idx and log the loss."""
        loss = self.train_h_batch(cur_idx, self.ThetaT)
        self.log['h_loss'].append(loss)

    def get_batch_Theta_candidate(self, cur_idx):
        """
        Collect the neighbour candidates of the nodes in cur_idx.

        output:
        receptive_idx: sorted union of all candidate node ids (int array)
        aligned_local_map: for each node in cur_idx, the positions of its candidates in
                           receptive_idx
        raises:
        ValueError for an unknown neighbour_candidate_mode
        """
        if self.neighbour_candidate_mode == "full":
            receptive_idx = np.arange(self.data_size).astype(int)
            aligned_local_map = [receptive_idx for _ in range(len(cur_idx))]
        elif self.neighbour_candidate_mode == "sparse":
            cur_candidate = [self.neighbour_candidate[i] for i in cur_idx]
            # pivot[k]:pivot[k+1] is the slice of the flat list belonging to node k
            pivot = np.cumsum([0] + [len(i) for i in cur_candidate])
            flat_candidate = list(itertools.chain.from_iterable(cur_candidate))
            receptive_idx, local_map = np.unique(flat_candidate, return_inverse=True)
            receptive_idx = receptive_idx.astype(int)
            aligned_local_map = [local_map[pivot[i]:pivot[i+1]] for i in range(len(pivot)-1)]
        else:
            raise ValueError("Unrecognized neighbour candidate mode "
                             f"{self.neighbour_candidate_mode}.")

        return receptive_idx, aligned_local_map

    def train_h_batch(self, batch_idx, ThetaT):
        """
        Take one Adam step on h for the nodes in batch_idx with ThetaT held fixed.

        Only the nodes referenced by the nonzero columns of ThetaT[batch_idx] are passed
        through h; ThetaT is re-indexed onto them.
        output: the MSE loss (float) computed before the update
        """
        self.h.train()
        # prepare samples
        cur_ThetaT = ThetaT[batch_idx, :].tocoo()
        X_ind, local_map = np.unique(cur_ThetaT.col, return_inverse=True)
        target = torch.from_numpy(self.Omega[batch_idx, :]).float().to(self.device)
        if len(X_ind) == 0:
            # no nonzeros in these rows: the prediction is identically zero and has no
            # gradient, so report the loss without updating h
            with torch.no_grad():
                return self.h_loss_func(torch.zeros_like(target), target).item()
        cur_ThetaT_local = sp.coo_matrix((cur_ThetaT.data, (cur_ThetaT.row, local_map)),
                                         shape=(cur_ThetaT.shape[0], len(X_ind)))
        cur_ThetaT_local = scipy_coo_to_torch_sparse(cur_ThetaT_local).float().to(self.device)
        # construct the loss; the forward pass stays in train mode (get_torch_H would switch
        # h to eval mode) and uses a single batch so batch-dependent layers see all rows
        H = self._forward_h(X_ind, len(X_ind), local_verbose=False)
        pred = torch.sparse.mm(cur_ThetaT_local, H)
        loss = self.h_loss_func(pred, target)
        # execute the training
        self.h_opt.zero_grad()
        loss.backward()
        self.h_opt.step()
        return loss.item()

    def init_process_h(self, local_verbose=True):
        """
        Initialize h by training it for h_init_epoch epochs with ThetaT fixed to A^T.
        """
        init_ThetaT = self.A.transpose().tolil()
        max_batch = self.h_init_epoch * self._max_iter
        bb = 0
        tic = time()
        for _ in range(self.h_init_epoch):
            for cur_idx in self._sampler():
                cur_loss = self.train_h_batch(cur_idx, init_ThetaT)
                self.log["h_loss"].append(cur_loss)
                if self.verbose and local_verbose:
                    cur_time = time() - tic
                    ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                    sys.stdout.write(f"{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                                     f"ETA {ETA:.1f}\r")
                    sys.stdout.flush()
                bb += 1
        if self.verbose and local_verbose:
            print()

    def post_update(self):
        """Recompute every row of ThetaT with the final h (one pass, timings not logged)."""
        if self.verbose:
            print("Post precessing... Recomputing Theta...")
        tic_start = time()
        for it, cur_idx in enumerate(self._sampler()):
            self.update_Theta(cur_idx, log=False)
            if self.verbose:
                cur_time = time() - tic_start
                ETA = cur_time / (it + 1) * (self._max_iter - it - 1)
                # fraction of batches completed so far (reaches 1.000 on the last batch)
                print(f'{(it + 1)/self._max_iter:.3f} finished in {cur_time:.1f} s. '
                      f'ETA: {ETA:.1f} s',
                      end='\r', flush=True)
        if self.verbose:
            print()

    def infer_torch_node_approximal_features(self):
        """Return (and cache in self.torch_features) the approximation ThetaT @ h(X)."""
        with torch.no_grad():
            # infer H
            H = self.get_torch_H()
            # get final results
            torch_ThetaT = scipy_coo_to_torch_sparse(self.ThetaT.tocoo()).float().to(self.device)
            self.torch_features = torch.sparse.mm(torch_ThetaT, H).detach()
        return self.torch_features

    def save(self, res_folder):
        """Pickle ThetaT, the log and h's state dict into res_folder."""
        with open(os.path.join(res_folder, "ThetaT.pkl"), 'wb') as fout:
            pickle.dump(self.ThetaT, fout)
        with open(os.path.join(res_folder, "log.pkl"), "wb") as fout:
            pickle.dump(self.log, fout)
        with open(os.path.join(res_folder, "h_model_stat.pkl"), "wb") as fout:
            pickle.dump(self.h.state_dict(), fout)
        if self.verbose:
            print(f"Results are saved in {res_folder}.")

    def load(self, res_folder):
        """
        Restore ThetaT, the log and h's weights written by save().

        NOTE: uses pickle — only load folders from trusted sources.
        """
        if self.verbose:
            print(f"Loading results from {res_folder}.")
        with open(os.path.join(res_folder, "ThetaT.pkl"), 'rb') as fin:
            self.ThetaT = pickle.load(fin)
        with open(os.path.join(res_folder, "log.pkl"), "rb") as fin:
            self.log = pickle.load(fin)
        with open(os.path.join(res_folder, "h_model_stat.pkl"), "rb") as fin:
            self.h.load_state_dict(pickle.load(fin))

class SDMP_dict: # pylint: disable=too-many-instance-attributes
    """
    Main class to learn the sparse decomposition of message passing with dictionary learning
    structure
    Inputs:
      self_feature: how to deal with the self features. If "extend_base", the last element of
                    theta will be extended for the base refering to the self features.
    """
    NORMAL_VAR = 1.0e-6 # small value for introduce Gaussion noise
    TINY = 1e-9 # small value to prevent numerical overflow
    UTIL_THRESH = 1.0e-2
    def __init__(self, # pylint: disable=too-many-arguments
                 H,
                 Omega,
                 A,
                 dict_size,
                 epoch=5,
                 batch_size=64,
                 eval_step=20,
                 eval_batch_size=-1,
                 self_feature="default",
                 n_theta_nonzero=20,
                 lam2=1e-3,
                 O_loop_cnt=1,
                 O_resample_method="no",
                 O_warmup=40,
                 O_resample_step=20,
                 HHT=None,
                 HHTHHT=None,
                 HOmegaT=None,
                 F_HHT=None,
                 ori_target_square_norm=None,
                 device="cpu",
                 verbose=True):
        """
        Store the hyper-parameters and precompute the H-based products.

        input:
        H: numpy matrix whose rows are combined by the dictionary O (O @ H in evaluation)
        Omega: numpy target matrix, one row per sample
        A: matrix of shape (data_size, basis_dim); random rows of A initialise O
        dict_size: number of dictionary atoms (rows of O)
        eval_batch_size: batch size of the precomputations and of evaluation; -1 means
                         everything in one batch
        O_warmup: global iterations before O starts being updated
        O_resample_step, O_resample_method: how often / how O is resampled after warm-up
        HHT, HHTHHT, HOmegaT, F_HHT: optional precomputed H @ H^T, H @ H^T @ H @ H^T,
                                     H @ Omega^T and the Frobenius norm of H @ H^T;
                                     computed here when None
        ori_target_square_norm: per-sample values dividing the squared errors in
                                eval_metrics_torch; NOTE(review): must be provided before
                                evaluating — it is not derived here
        """
        self.dict_size = dict_size
        self.epoch = epoch
        self.batch_size = batch_size
        self.eval_step = eval_step
        self.eval_batch_size = eval_batch_size
        self.self_feature=self_feature
        self.n_theta_nonzero = n_theta_nonzero
        self.lam2 = lam2
        self.O_loop_cnt = O_loop_cnt
        self.O_resample_method = O_resample_method
        self.O_warmup = O_warmup
        self.O_resample_step = O_resample_step
        self.HHT = HHT
        self.HHTHHT = HHTHHT
        self.HOmegaT = HOmegaT
        self.F_HHT = F_HHT
        self.ori_target_square_norm = ori_target_square_norm
        self.device = device
        self.verbose = verbose

        self.H = H
        self.Omega = Omega
        self.A = A

        if TB:
            log_dir = "runs/tests"
            # start from a clean directory, but do not fail if it does not exist yet
            shutil.rmtree(log_dir, ignore_errors=True)
            self.writer = SummaryWriter(log_dir)

        if HHT is None:
            self.HHT = torch_batch_matrix_mul_matrix_list(self.H, self.H.transpose(),
                                                          batch_size=self.eval_batch_size,
                                                          device=self.device)
        if HHTHHT is None:
            # H (H^T H) H^T: multiplying through the small feature x feature matrix H^T H
            HTH = torch_batch_matrix_mul_matrix_list(self.H.transpose(),
                                                     self.H,
                                                     batch_size=self.eval_batch_size,
                                                     device=self.device)
            self.HHTHHT = torch_batch_matrix_mul_matrix_list(self.H, [HTH, self.H.transpose()],
                                                             batch_size=self.eval_batch_size,
                                                             device=self.device)
        # legacy misspelled attribute, kept as an alias for backward compatibility
        self.HHTHT = self.HHTHHT
        if HOmegaT is None:
            self.HOmegaT = torch_batch_matrix_mul_matrix_list(self.H,
                                                              self.Omega.transpose(),
                                                              batch_size=self.eval_batch_size,
                                                              device=self.device)
        if F_HHT is None:
            self.F_HHT = np.linalg.norm(self.HHT)

        # initialize parameters
        self.log = defaultdict(list)
        self.data_size, self.basis_dim = self.A.shape # dictionary basis dimension and the data size
        _, self.feature_dim = self.H.shape
        if self.eval_batch_size == -1:
            # -1 means "one batch": resolve it like SDMP does so the evaluation loops work
            self.eval_batch_size = self.data_size
        self.global_iter_cnt = None
        self._iter_cnt, self._sample_idx = None, None
        self._max_iter = int(np.ceil(self.data_size / self.batch_size))
        self.per_sample_regret = None # per sample regret to pick resampled O
        self.O, self.Theta = None, None

        self.util_thresh = SDMP_dict.UTIL_THRESH * self.data_size / self.dict_size

    def _sampler(self):
        """
        Generator over one shuffled epoch: yields the sample indices of each mini-batch.

        Side effects: self._sample_idx holds the shuffled permutation and self._iter_cnt the
        index of the batch currently being yielded.
        """
        self._sample_idx = np.arange(self.data_size)
        np.random.shuffle(self._sample_idx)
        for step in range(self._max_iter):
            self._iter_cnt = step
            start = step * self.batch_size
            stop = min(start + self.batch_size, self.data_size)
            yield self._sample_idx[start:stop]

    def _init_O_Theta(self):
        """
        Initialise the dictionary O and the coefficient matrix Theta.

        output:
        O: dict_size rows of A drawn uniformly at random (with replacement, so rows may
           repeat), as a dense float array with Gaussian noise of standard deviation
           SDMP_dict.NORMAL_VAR added to break ties between duplicated rows
        Theta: all-zero float array of shape (dict_size, data_size)
        """
        # TODO: add pre-defined initialization of O. If larger than expected,
        # sample down, if smaller, fill with random rows of A.
        # same draw as np.random.choice(list(range(data_size)), ...) without building a list
        seed = np.random.choice(self.data_size, self.dict_size)
        O = self.A[seed]
        if not isinstance(self.A, np.ndarray):
            O = O.toarray()
        if not np.issubdtype(O.dtype, np.floating):
            # in-place addition of float noise fails on integer-valued adjacency matrices
            O = O.astype(float)
        # NOTE: np.random.normal takes a standard deviation, despite the constant's name
        O += np.random.normal(0, SDMP_dict.NORMAL_VAR, O.shape)
        Theta = np.zeros([self.dict_size, self.data_size])
        if DEBUG:
            self.log['O_init'].append(O.copy())

        return O, Theta

    def eval_metrics_torch(self, local_verbose=True):
        """
        Evaluate how well Theta^T @ O @ H reconstructs Omega.

        output:
        regret: sum over samples of the squared row error ||(Theta^T O H)_i - Omega_i||^2
        rel_regret: mean over samples of that error divided by ||Omega_i||^2
        rel_weighted_regret: mean over samples of that error divided by
                             ori_target_square_norm[i]
        per_sample_regret: the per-sample error divided by ori_target_square_norm
        """
        if self.verbose and DEBUG and local_verbose:
            print("Evaluating...")
        list_diff, list_norm = [], []
        batch_size = self.eval_batch_size
        if batch_size is None or batch_size <= 0:
            # -1 (the constructor default) means "evaluate all samples in one batch"
            batch_size = max(self.data_size, 1)
        max_batch = int(np.ceil(self.data_size / batch_size))

        with torch.no_grad():
            tch_O = torch.from_numpy(self.O).float().to(self.device)
            tch_H = torch.from_numpy(self.H).float().to(self.device)
            tic = time()
            # O @ H is shared by all batches: compute it once
            tch_OH = torch.mm(tch_O, tch_H)
            for bb in range(max_batch):
                start = bb * batch_size
                stop = min((bb+1) * batch_size, self.data_size)
                cur_Theta = torch.from_numpy(self.Theta[:, start:stop]).float().to(self.device)
                ThetaT_dot_O_dot_H = torch.mm(cur_Theta.t(), tch_OH)
                cur_Omega = torch.from_numpy(self.Omega[start:stop, :]).float().to(self.device)
                diff_row_square_sum = torch.norm(ThetaT_dot_O_dot_H - cur_Omega, dim=1) ** 2
                self_norm_square_sum = torch.norm(cur_Omega, dim=1) ** 2

                list_diff.append(diff_row_square_sum.cpu().detach().numpy())
                list_norm.append(self_norm_square_sum.cpu().detach().numpy())
                if self.verbose and DEBUG and local_verbose:
                    cur_time = time() - tic
                    ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                    sys.stdout.write(f"{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                                     f"ETA {ETA:.1f}\r")
                    sys.stdout.flush()
                # end of main loop
            if self.verbose and DEBUG and local_verbose:
                print()
            diff = np.concatenate(list_diff)
            norm = np.concatenate(list_norm)
            regret = np.sum(diff)
            # use this to measure the final performance
            per_sample_regret = diff / self.ori_target_square_norm
            rel_regret = np.sum(diff/norm) / self.data_size
            rel_weighted_regret = np.sum(diff/self.ori_target_square_norm) / self.data_size
            return regret, rel_regret, rel_weighted_regret, per_sample_regret

    def eval_and_log(self):
        """
        Evaluate the current (O, Theta) pair and append metrics and sparsity statistics to
        self.log; also refreshes self.per_sample_regret.
        """
        (regret, rel_regret, weighted_rel_regret,
         self.per_sample_regret) = self.eval_metrics_torch(local_verbose=False)
        theta_cnt = np.sum(self.Theta != 0, axis=0)  # nonzeros in each column of Theta
        o_cnt = np.sum(self.O != 0, axis=1)          # nonzeros in each row of O
        entries = (
            ("regret", regret),
            ("rel_regret", rel_regret),
            ("weighted_rel_regret", weighted_rel_regret),
            ("Theta_col_nonzero_cnt", theta_cnt),
            ("O_row_nonzero_cnt", o_cnt),
            ("Theta_col_nonzero_stat", (np.mean(theta_cnt), np.std(theta_cnt))),
            ("O_row_nonzero_stat", (np.mean(o_cnt), np.std(o_cnt))),
        )
        for key, value in entries:
            self.log[key].append(value)
    def _display_stat(self, i=-1):
        """Print the regret metrics and nonzero statistics stored at position i of the log."""
        log = self.log
        theta_stat = log['Theta_col_nonzero_stat'][i]
        o_stat = log['O_row_nonzero_stat'][i]
        plus_minus = u"\u00B1"
        parts = [
            f"    Regret: {log['regret'][i]:.4f}",
            f" | Rel Regret: {log['rel_regret'][i]:.4f}",
            f" | W Rel Regret: {log['weighted_rel_regret'][i]:.4f}",
            f" || Theta col nonzeros: {theta_stat[0]:.3f} {plus_minus} {theta_stat[1]:.3f}",
            f" | O row nonzeros: {o_stat[0]:.3f} {plus_minus} {o_stat[1]:.3f}",
        ]
        print("".join(parts))

    def _display_more_stat(self, i=-1):
        """
        Print the timing statistics stored at position i of the log (debug output of fit).

        The O-related timings only exist after the warm-up and once resampling has run, so
        missing or malformed entries are skipped (best effort) instead of aborting training.
        """
        print(f"    Times: Theta all {self.log['time_Theta'][i]:.1f}"
              f" | Theta pre {self.log['time_Theta_pre'][i]:.1f}"
              f" | Theta lasso {self.log['time_Theta_lasso'][i]:.1f}"
              f" | O all {self.log['time_O'][i]:.1f}"
              f" | O resample {self.log['time_resample_O'][i]:.1f}")
        if self.global_iter_cnt <= self.O_warmup:
            return
        # errors that only mean "this statistic is not available yet"
        unavailable = (IndexError, KeyError, TypeError, ValueError, ZeroDivisionError)
        try:
            print(f" | O pre {self.log['time_O_pre'][i]:.1f}"
                  f" | O update {self.log['time_O_update'][i]:.1f}.", end=" ")
        except unavailable:
            pass
        try:
            # number of resampling rounds between two evaluations
            resample_v_eval = int(self.eval_step / self.O_resample_step)
            # NOTE(review): with i=-1 this window excludes the newest entry — confirm intent
            resampled = self.log['O_resampled'][i-resample_v_eval:i]
            # int(): np.sum of an empty list is the float 0.0, which ':d' cannot format
            resampled_cnt = int(np.sum([len(j) for j in resampled]))
            print(f" | O resample execute pre {self.log['time_resample_O_execute'][i]:.1f}"
                  f" | O update {self.log['time_resample_O_post'][i]:.1f}."
                  " | O resampled count "
                  f"{resampled_cnt:d}")
        except unavailable:
            pass

    def fit(self): # pylint: disable=too-many-branches, too-many-statements
        """
        Main procedure of fitting.

        Initializes O and Theta with ``self._init_O_Theta()`` and logs an initial evaluation,
        then runs ``self.epoch`` epochs over the mini-batches yielded by ``self._sampler()``.
        In every iteration it:
          1. re-fits the columns of Theta in the current mini-batch (``update_Theta``);
          2. once ``global_iter_cnt > O_warmup``, updates O (``update_O_proximal_gradient``);
          3. once past warm-up and every ``O_resample_step`` global iterations, resamples rows
             of O (``resample_O(self.O_resample_method)``);
          4. every ``eval_step`` global iterations, calls ``eval_and_log`` (and prints progress
             when ``self.verbose``).
        Wall-clock times of the Theta, O and resampling phases are appended to ``self.log``
        every iteration. After the main loop ``post_update`` recomputes Theta and a final
        evaluation is logged.

        Returns:
            self
        """
        # initialization
        if self.verbose:
            print("Initializing...")
        self.O, self.Theta = self._init_O_Theta()

        self.global_iter_cnt = 0
        # only used for the ETA estimate printed at evaluation steps
        total_iter = self.epoch * self._max_iter

        #main loop
        # evaluate the initial factors before any update
        self.eval_and_log()
        if self.verbose:
            print("Training started...")
            self._display_stat()
            print("-"*27)
            print()

        if TB:
            self.writer.add_embedding(self.O, global_step=self.global_iter_cnt, tag="O")
            self.writer.add_histogram("tch_O", self.O, global_step=self.global_iter_cnt, bins=100)
            # track the first (up to) 200 entries of the first row of O
            self.writer.add_scalars("O_row",
                                    {str(indx): val for indx, val in enumerate(self.O[0, :200])},
                                    global_step=self.global_iter_cnt)

        tic_start = time()
        for e in range(self.epoch):
            for it, cur_idx in enumerate(self._sampler()):
                # tic_iter_start = time()
                self.global_iter_cnt += 1
                # phase Theta: re-fit the columns of Theta selected by the mini-batch
                tic_Theta_start = time()
                self.update_Theta(cur_idx)
                self.log["time_Theta"].append(time()-tic_Theta_start)
                # phase O start when warm up finished
                # (time_O is still logged during warm-up, where it is ~0)
                tic_O_start = time()
                if self.global_iter_cnt > self.O_warmup:
                    if self.global_iter_cnt == self.O_warmup + 1 and self.verbose:
                        print("---Warmup finished---Start to update O.")
                    # alternative solver; note the method is spelled update_O_BCD_lasso
                    # self.update_O_BCD_Lasso()
                    self.update_O_proximal_gradient()
                self.log["time_O"].append(time()-tic_O_start)
                # resample rows of O every O_resample_step global iterations after warm-up
                tic_O_resample = time()
                if self.global_iter_cnt > self.O_warmup and\
                    self.global_iter_cnt % self.O_resample_step == 0:
                    self.resample_O(self.O_resample_method)
                self.log["time_resample_O"].append(time()-tic_O_resample)
                if self.global_iter_cnt % self.eval_step ==0:
                    self.eval_and_log()
                    if self.verbose:
                        elapsed_time = time() - tic_start
                        # linear extrapolation from the mean time per global iteration so far
                        ETA = elapsed_time /\
                            self.global_iter_cnt * (total_iter - self.global_iter_cnt)
                        print("-"*5)
                        print(f"Epoch: {e} | Iter: {it} | Global iter: {self.global_iter_cnt} "
                              f"Elapsed time: {elapsed_time:.1f} | ETA: {ETA:.1f}")
                        self._display_stat()
                        if DEBUG: # print more info for debugging mode
                            self._display_more_stat()
                    if TB:
                        self.writer.add_scalar('regret', self.log['regret'][-1],
                                               self.global_iter_cnt / self.eval_step)
                # with TB enabled, Theta is logged on every iteration (not only at eval steps)
                if TB:
                    self.writer.add_embedding(self.Theta, global_step=self.global_iter_cnt,
                                              tag="Theta")
                    self.writer.add_histogram("Theta", self.Theta,
                                              global_step=self.global_iter_cnt, bins=100)
            # end main loop
        # Post update: recompute Theta with the final O, then evaluate once more
        if self.verbose:
            print("-"*27)
            print()
            print("Started post processing...")
        self.post_update()
        self.eval_and_log()
        if self.verbose:
            print(f"Training finished in {(time()-tic_start):.1f} s.")
            self._display_stat()

        if TB:
            self.writer.close()
        return self
        # end fit

    def fit_debug(self): # pylint: disable=too-many-branches, too-many-statements
        """
        Debugging variant of the main fitting procedure.

        Differences from ``fit``, as visible in this method:
          - after ``_init_O_Theta`` it overwrites Theta and O with arrays loaded from the
            hard-coded files ``res/test/theta_densetheta.npy`` and ``res/test/O_densetheta.npy``
            (resolved relative to the current working directory);
          - the Theta update is disabled (only a near-zero ``time_Theta`` is logged);
          - O is updated with ``update_O_proximal_gradient`` on every iteration, with no warm-up;
          - resampling additionally requires ``e > 0``, i.e. it never happens in the first epoch;
          - with TB enabled it also logs the per-row sums of ``Theta * Theta`` as "O_row_util".

        Returns:
            self
        """
        print("!!!!!!!!!! DEBUGGING FITTING!!!!!!!!!!")
        # initialization
        if self.verbose:
            print("Initializing...")
        self.O, self.Theta = self._init_O_Theta()
        # replace the fresh initialization with previously saved arrays (debug fixtures)
        print("Loading Theta and O directly..")
        prefix="densetheta"
        self.Theta = np.load("res/test/theta_"+prefix+".npy")
        self.O = np.load("res/test/O_"+prefix+".npy")
        # dummy Theta timings, since update_Theta is never called below
        self.log["time_Theta"].append(0.0)
        self.log["time_Theta_pre"].append(0.0)
        self.log["time_Theta_lasso"].append(0.0)

        self.global_iter_cnt = 0
        # only used for the ETA estimate printed at evaluation steps
        total_iter = self.epoch * self._max_iter

        #main loop
        self.eval_and_log()
        if self.verbose:
            print("Training started...")
            self._display_stat()
            print("-"*27)
            print()

        if TB:
            self.writer.add_embedding(self.O, global_step=self.global_iter_cnt, tag="O")
            self.writer.add_histogram("tch_O", self.O, global_step=self.global_iter_cnt, bins=100)
            self.writer.add_scalars("O_row",
                                    {str(indx): val for indx, val in enumerate(self.O[0, :200])},
                                    global_step=self.global_iter_cnt)

        tic_start = time()
        for e in range(self.epoch):
            for it, _ in enumerate(self._sampler()):
                # tic_iter_start = time()
                self.global_iter_cnt += 1
                # phase Theta (disabled in debug mode; the timing below is ~0)
                tic_Theta_start = time()
                # self.update_Theta(cur_idx)
                self.log["time_Theta"].append(time()-tic_Theta_start)
                # phase O: updated on every iteration here (no O_warmup check)
                tic_O_start = time()
                # self.update_O_BCD_lasso()
                self.update_O_proximal_gradient()
                self.log["time_O"].append(time()-tic_O_start)
                # resampling after warm-up, never during the first epoch (e > 0)
                tic_O_resample = time()
                if self.global_iter_cnt > self.O_warmup and e > 0 and\
                    self.global_iter_cnt % self.O_resample_step == 0:
                    self.resample_O(self.O_resample_method)
                self.log["time_resample_O"].append(time()-tic_O_resample)
                if self.global_iter_cnt % self.eval_step ==0:
                    self.eval_and_log()
                    if self.verbose:
                        elapsed_time = time() - tic_start
                        # linear extrapolation from the mean time per global iteration so far
                        ETA = elapsed_time /\
                            self.global_iter_cnt * (total_iter - self.global_iter_cnt)
                        print("-"*5)
                        print(f"Epoch: {e} | Iter: {it} | Global iter: {self.global_iter_cnt} "
                              f"Elapsed time: {elapsed_time:.1f} | ETA: {ETA:.1f}")
                        self._display_stat()
                        if DEBUG: # print more info for debugging mode
                            self._display_more_stat()
                    if TB:
                        self.writer.add_scalar('regret', self.log['regret'][-1],
                                               self.global_iter_cnt / self.eval_step)
                if TB:
                    self.writer.add_embedding(self.Theta,
                                              global_step=self.global_iter_cnt, tag="Theta")
                    self.writer.add_histogram("Theta", self.Theta,
                                              global_step=self.global_iter_cnt, bins=100)
                    ## for showing the utility of rows of O
                    # (sum of squares of each row of Theta, i.e. per dictionary row)
                    ut = np.sum(self.Theta * self.Theta, axis=1)
                    self.writer.add_scalars("O_row_util",
                                            {str(indx): val for indx, val in enumerate(ut)},
                                            global_step=self.global_iter_cnt)
            # end main loop
        # Post update
        if self.verbose:
            print("-"*27)
            print()
            print("Started post processing...")
        self.post_update()
        self.eval_and_log()
        if self.verbose:
            print(f"Training finished in {(time()-tic_start):.1f} s.")
            self._display_stat()
        if TB:
            self.writer.close()
        return self
        # end fit

    def update_Theta(self, cur_idx, log=True):
        """
        Re-fit the columns ``cur_idx`` of Theta with LassoLars while O stays fixed.

        With ``OH = O @ H``, the regression targets are the rows ``self.Omega[cur_idx, :]`` and
        the design matrix is ``OH^T``. Each selected column ``Theta[:, z]`` is fitted so that
        ``OH^T Theta[:, z]`` approximates ``Omega[z, :]^T``. ``alpha=0`` together with
        ``max_iter=self.n_theta_nonzero`` caps the number of LARS iterations.

        Parameters:
            cur_idx: indices of the columns of Theta (rows of Omega) to update, e.g. a
                     mini-batch from ``_sampler`` or a list of row indices.
            log: if True, append the pre-computation time to ``self.log["time_Theta_pre"]``
                 and the solver time to ``self.log["time_Theta_lasso"]``.
        """
        # pre-computations
        ## gram OHH'O', with O c by n, H n by f
        ### method 1: OH H' O'
        ###     complexity cnf + cfn + cnc
        ### method 2: O (HH') O'
        ###     complexity cnn + cnc
        ### diff = nc(2f - n)
        # a single timer covers all pre-computations: OH, the gram matrix and Xy
        tic = time()
        OH = torch_batch_matrix_mul_matrix_list(self.O, self.H,
                                                batch_size=self.eval_batch_size,
                                                device=self.device)
        gram = torch_batch_matrix_mul_matrix_list(OH, OH.transpose(),
                                                  batch_size=self.eval_batch_size,
                                                  device=self.device)
        ## Xy
        cur_target = self.Omega[cur_idx, :]
        Xy = torch_batch_matrix_mul_matrix_list(OH, cur_target.transpose(),
                                                batch_size=self.eval_batch_size,
                                                device=self.device)
        if log:
            self.log["time_Theta_pre"].append(time()-tic)
        tic = time()
        X = OH.transpose()
        y = cur_target.transpose()
        lars_params = {"alpha": 0, "precompute": gram, "max_iter": self.n_theta_nonzero,
                       "fit_intercept": False}
        try:
            reg = linear_model.LassoLars(normalize=False, **lars_params)
        except TypeError:
            # scikit-learn >= 1.4 removed `normalize`; no scaling is its only behavior there
            reg = linear_model.LassoLars(**lars_params)
        reg.fit(X=X, y=y, Xy=Xy)
        # coef_ is (n_targets, dict_size), or (dict_size,) for a single target: transpose and
        # reshape it to the exact shape of the Theta slice being written.
        theta_slice_shape = self.Theta[:, cur_idx].shape
        self.Theta[:, cur_idx] = np.asarray(reg.coef_).transpose().reshape(theta_slice_shape)
        if log:
            self.log["time_Theta_lasso"].append(time()-tic)

    def update_O_BCD_lasso(self):
        """
        Block coordinate descent update of O, one row at a time, with LassoLars.

        For ``self.O_loop_cnt`` sweeps over the ``self.dict_size`` rows j of O, it builds a
        LassoLars problem with:
          - precomputed gram ``self.HHTHHT``;
          - design matrix ``self.HHT.transpose()``;
          - penalty ``self.lam2 / c_z**2``, where ``c_z = (Theta Theta^T)[j, j] + SDMP_dict.TINY``.
        It adds the solution to row j and rescales that row so its Euclidean norm is at most 1.
        Accumulated timings are appended to ``self.log["time_O_pre"]`` and
        ``self.log["time_O_update"]``.

        NOTE(review): ``fit`` calls ``update_O_proximal_gradient`` instead; the call to this
        method is commented out there.
        """
        t_O_pre, t_O_update = 0., 0.
        with torch.no_grad():
            tic = time()
            # outer preprocessing
            # NOTE(review): torch_batch_matrix_mul_matrix_list casts to float32 (see its .float()),
            # so ThetaThetaT is float32 while tch_Omega / tch_O / tch_H keep the numpy dtypes;
            # torch matmul needs matching dtypes -- confirm self.Omega, self.O, self.H are float32.
            ThetaThetaT = torch_batch_matrix_mul_matrix_list(self.Theta, self.Theta.transpose(),
                                                             batch_size=self.eval_batch_size,
                                                             device=self.device)
            tch_ThetaThetaT = torch.from_numpy(ThetaThetaT).to(self.device)
            tch_Omega = torch.from_numpy(self.Omega).to(self.device)
            # on CPU, tch_O shares memory with self.O, so the row writes below also modify self.O
            tch_O = torch.from_numpy(self.O).to(self.device)
            tch_H = torch.from_numpy(self.H).to(self.device)
            t_O_pre += time() - tic
            for _ in range(self.O_loop_cnt):
                for j in range(self.dict_size):
                    # prepare the gram and Xy
                    tic = time()
                    c_z = tch_ThetaThetaT[j, j] + SDMP_dict.TINY
                    ## gram
                    gram = self.HHTHHT
                    ## Xy
                    tch_Theta_j = torch.from_numpy(self.Theta[j, :]).float().to(self.device)
                    yT = tch_Theta_j @ tch_Omega - tch_ThetaThetaT[j, :] @ tch_O @ tch_H
                    # NOTE(review): in-place matmul on a tensor; confirm the installed torch
                    # supports `@=`, otherwise use `yT = yT @ tch_H.t()`.
                    yT @= tch_H.t()
                    yT *= c_z
                    y = yT.t()
                    Xy = tch_H.t() @ y
                    Xy = tch_H @ Xy
                    Xy = Xy.cpu().detach().numpy().reshape([-1, 1])
                    ## Others
                    X = self.HHT.transpose()
                    y = y.cpu().detach().numpy().reshape([-1, 1])
                    lam = self.lam2 / c_z**2
                    lam = lam.cpu().detach().tolist()
                    t_O_pre += time() - tic
                    # main update
                    tic = time()
                    # print(lam)
                    # NOTE(review): `normalize` was removed in scikit-learn 1.4
                    reg = linear_model.LassoLars(alpha=lam, precompute=gram, max_iter=10,
                                                 normalize=False, fit_intercept=False)
                    # print(X.shape, y.shape, Xy.shape)
                    reg.fit(X=X, y=y, Xy=Xy)
                    ## post: add the increment to row j, then rescale it to norm <= 1
                    tch_Delta = torch.from_numpy(np.array(reg.coef_)).to(self.device)
                    tmp = tch_O[j, :] + tch_Delta
                    tch_O[j, :] = tmp / torch.clamp(torch.norm(tmp), min=1)
                    t_O_update += time() - tic
                    # end of inner
                # end of outer loop
            self.O[:] = tch_O.cpu().detach().numpy()
            # end of torch env
        self.log["time_O_pre"].append(t_O_pre)
        self.log["time_O_update"].append(t_O_update)
        # end of method update_O_BCD_lasso

    def update_O_proximal_gradient(self):
        """
        Applying the proximal gradient algorithm to update O.

        Runs ``self.O_loop_cnt`` iterations, each of which:
          - computes the gradient at the current iterate O. This is the product of
            [H, H^T, O^T] with Theta Theta^T, computed by torch_batch_matrix_list_mul_matrix,
            minus H Omega^T Theta^T;
          - takes a gradient step of size ``eta = 2 / (||Theta Theta^T||_F * self.F_HHT)``;
          - applies ``torch_wthresh`` with threshold ``self.lam2 * eta`` (proximal step);
          - rescales every row of O to Euclidean norm at most 1.
        The result is written back into ``self.O`` in place. The step size and the timings are
        appended to ``self.log``.
        """
        tic_all = time()
        # common computation (independent of O, hence done once outside the loop)
        ThetaThetaT = torch_batch_matrix_mul_matrix_list(self.Theta, self.Theta.transpose(),
                                                         batch_size=self.eval_batch_size,
                                                         device=self.device)
        tch_ThetaThetaT = torch.from_numpy(ThetaThetaT).to(self.device)
        HOmegaTThetaT = torch_batch_matrix_list_mul_matrix([self.H, self.Omega.transpose()],
                                                           self.Theta.transpose(),
                                                           batch_size=self.eval_batch_size,
                                                           device=self.device)
        tch_HOmegaTThetaT = torch.from_numpy(HOmegaTThetaT).to(self.device)
        tch_O = torch.from_numpy(self.O).to(self.device)
        # step size computation
        tch_eta = torch.norm(tch_ThetaThetaT) * torch.tensor(self.F_HHT)
        tch_eta = torch.tensor(2.) / tch_eta
        self.log["eta"].append(tch_eta.cpu().detach().tolist())
        tch_thresh = torch.tensor(self.lam2) * tch_eta

        # torch.from_numpy shares memory with self.O (on CPU), and self.O is overwritten in
        # place below, so keep an independent snapshot of the starting point for O_diff.
        tch_O_pre = tch_O.clone()
        self.log["time_O_pre"].append(time()-tic_all)
        tic_all = time()
        tch_grad = None
        for _ in range(self.O_loop_cnt):
            # gradient at the CURRENT iterate (self.O is only written back after the loop)
            cur_O = tch_O.cpu().detach().numpy()
            grad = torch_batch_matrix_list_mul_matrix([self.H, self.H.transpose(),
                                                       cur_O.transpose()],
                                                      ThetaThetaT,
                                                      batch_size=self.eval_batch_size,
                                                      device=self.device)
            tch_grad = torch.from_numpy(grad).to(self.device)
            tch_grad = tch_grad - tch_HOmegaTThetaT
            tch_O = tch_O - tch_eta * tch_grad.t()
            # proximal update
            tch_O = torch_wthresh(tch_O, tch_thresh)
            # O projection: rescale every row to Euclidean norm at most 1 (all rows at once)
            row_norms = torch.norm(tch_O, dim=1, keepdim=True)
            tch_O = tch_O / torch.clamp(row_norms, min=1)
            # end of O loop
        self.O[:] = tch_O.cpu().detach().numpy()
        self.log["time_O_update"].append(time()-tic_all)

        if TB:
            self.writer.add_scalar('regret', self.log['regret'][-1],
                                   self.global_iter_cnt / self.eval_step)
            self.writer.add_embedding(tch_O, global_step=self.global_iter_cnt, tag="O")
            self.writer.add_histogram("tch_O", tch_O, global_step=self.global_iter_cnt, bins=100)
            if tch_grad is not None:  # no gradient exists when O_loop_cnt == 0
                self.writer.add_embedding(tch_grad, global_step=self.global_iter_cnt,
                                          tag="O_grad")
                self.writer.add_histogram("tch_grad", tch_grad,
                                          global_step=self.global_iter_cnt, bins=100)
            self.writer.add_embedding(tch_O - tch_O_pre,
                                      global_step=self.global_iter_cnt, tag="O_diff")
            self.writer.add_histogram("O_diff", tch_O - tch_O_pre,
                                      global_step=self.global_iter_cnt, bins=100)
            self.writer.add_scalars("O_row",
                                    {str(indx): val for indx, val in enumerate(tch_O[0, :200])},
                                    global_step=self.global_iter_cnt)
            self.writer.add_histogram("O_row_hist", tch_O[0, :],
                                      global_step=self.global_iter_cnt, bins=100)
            if tch_grad is not None:
                self.writer.add_scalars("O_row_grad",
                                        {str(indx):
                                         val for indx, val in enumerate(tch_grad.t()[0, :200])},
                                        global_step=self.global_iter_cnt)
        # end of method update_O_proximal_gradient

    def resample_O(self, method):
        """
        Replace under-used rows of the dictionary O by noisy rows of A.

        A row j of O is considered under-used when ``sum(|Theta[j, :]|) <= self.util_thresh``.
        Each such row is overwritten with a row of A plus Gaussian noise (``SDMP_dict.NORMAL_VAR``
        is passed as the standard deviation). The Theta columns of the A rows used are then
        re-fitted with ``update_Theta``. Timings, the replaced O rows and the fed A rows are
        appended to ``self.log``.

        Parameters:
            method: "no" (do nothing),
                    "uniform" (feed rows of A drawn uniformly at random, without replacement
                    whenever A has enough rows), or
                    "greedy" (feed the rows of A with the largest per-sample regret).
        Raises:
            ValueError: if ``method`` is not one of "no", "uniform", "greedy".
        """
        tic = time()
        if method == "no":
            return
        if method not in ["uniform", "greedy"]:
            raise ValueError("O resampling method can only be no, uniform, greedy. "
                             f"Invalid O resampling method {method}.")
        if self.verbose:
            print(f"Resampling O with {method}...")
        # get the index to resample
        utility = np.sum(np.abs(self.Theta), axis=1)
        O_rows_to_replace = np.where(utility <= self.util_thresh)[0]
        n_replace = len(O_rows_to_replace)
        if method == "greedy":
            # top-k rows of A by per-sample regret, in decreasing order. Idea from
            # https://stackoverflow.com/questions/58070203/
            # find-top-k-largest-item-of-a-list-in-original-order-in-python
            _, _, _, self.per_sample_regret = self.eval_metrics_torch(local_verbose=False)
            top_rows = heapq.nlargest(n_replace,
                                      enumerate(self.per_sample_regret),
                                      key=itemgetter(1))
            A_rows_to_feed = [ii for (ii, _) in top_rows]
        else:  # "uniform"
            n_rows_A = self.A.shape[0]
            A_rows_to_feed = np.random.choice(n_rows_A, n_replace,
                                              replace=n_replace > n_rows_A).tolist()
        # execute the resample
        for o_row, a_row in zip(O_rows_to_replace, A_rows_to_feed):
            self.O[o_row, :] = self.A[a_row, :].toarray() +\
                np.random.normal(0, SDMP_dict.NORMAL_VAR, (1, self.O.shape[1]))
        self.log['time_resample_O_execute'].append(time()-tic)
        tic = time()
        # roughly refill the affected Theta
        if len(A_rows_to_feed) > 0:
            self.update_Theta(A_rows_to_feed, log=False)
        self.log['time_resample_O_post'].append(time()-tic)
        self.log["O_resampled"].append(O_rows_to_replace)
        self.log["A_to_fill"].append(A_rows_to_feed)
        if DEBUG:
            self.log["O_init"][-1][O_rows_to_replace, :] = self.O[O_rows_to_replace, :]

    def post_update(self):
        """
        Recompute every column of Theta once, with O fixed, after training.

        Makes one full pass over ``self._sampler()`` and calls ``update_Theta`` (with timing
        logs disabled) on each mini-batch. When ``self.verbose``, progress is printed on a
        single line as the fraction of mini-batches completed, the elapsed time and an ETA.
        """
        if self.verbose:
            print("Post precessing... Recomputing Theta...")
        tic_start = time()
        for it, cur_idx in enumerate(self._sampler()):
            self.update_Theta(cur_idx, log=False)
            if self.verbose:
                # it + 1 batches are done now; use it consistently for fraction and ETA
                n_done = it + 1
                cur_time = time() - tic_start
                ETA = cur_time / n_done * (self._max_iter - n_done)
                print(f'{n_done/self._max_iter:.3f} finished in {cur_time:.1f} s. '
                      f'ETA: {ETA:.1f} s',
                      end='\r', flush=True)
        if self.verbose:
            print()

    def get_inference(self):
        """
        Return the reconstruction ``Theta^T @ (O @ H)``.

        Both products are computed with ``torch_batch_matrix_mul_matrix_list``, using the
        configured evaluation batch size and device.
        """
        mm_opts = {"batch_size": self.eval_batch_size, "device": self.device}
        dictionary_times_H = torch_batch_matrix_mul_matrix_list(self.O, self.H, **mm_opts)
        return torch_batch_matrix_mul_matrix_list(self.Theta.transpose(), dictionary_times_H,
                                                  **mm_opts)

    def save(self, res_folder):
        """
        Pickle the learned factors and the training log into ``res_folder``.

        Writes ``[O, Theta]`` to ``O_Theta.pkl`` and ``self.log`` to ``log.pkl``, in that order.
        """
        outputs = [("O_Theta.pkl", [self.O, self.Theta]),
                   ("log.pkl", self.log)]
        for file_name, payload in outputs:
            target_path = os.path.join(res_folder, file_name)
            with open(target_path, "wb") as handle:
                pickle.dump(payload, handle)
        if self.verbose:
            print(f"Results are saved in {res_folder}.")

    def load(self, res_folder):
        """
        Load O, Theta and the training log previously written by ``save``.

        Parameters:
            res_folder: folder containing ``O_Theta.pkl`` and ``log.pkl``.
        Raises:
            AttributeError: if the number of rows of the loaded O differs from
                ``self.dict_size``. In that case the current O, Theta and log are left
                untouched.

        Security: ``pickle.load`` can execute arbitrary code; only load folders you trust.
        """
        if self.verbose:
            print(f"Loading results from {res_folder}.")
        # read everything into locals first so that a failed load/validation does not leave
        # the object half-updated
        with open(os.path.join(res_folder, "O_Theta.pkl"), 'rb') as fin:
            loaded_O, loaded_Theta = pickle.load(fin)
        with open(os.path.join(res_folder, "log.pkl"), "rb") as fin:
            loaded_log = pickle.load(fin)
        loaded_dict_size = loaded_O.shape[0]
        if self.dict_size != loaded_dict_size:
            raise AttributeError("Dict size from loaded results does not match the configuration: "
                                 f"{self.dict_size} != {loaded_dict_size}")
        self.O, self.Theta, self.log = loaded_O, loaded_Theta, loaded_log

class SparseDictionaryLearning: # pylint: disable=too-many-instance-attributes
    """
    Solve the sparse dictionary learning problem with potential weight. Given signal
    A={A1^T; A2^T;...; An^T} (n by F), whose row vectors Ai^T are signals. We wish to learn a
    dictionary O (c by n), whose rows are base vector, along with coefficient a={a1, a2, ..., ac}
    with ai (n), so that we can appoximate Ai by Ai^T=ai^T O. To find O and a, we need to solve the
    optimization problem (Original)
        minimize_{O, a} \sum_z 1/2|a_z^T O - Az^T|_2^2+\lambda |a_z|_1,
        s.t. O_{i*}O_{i*}^T \le 1, \forall i.
    Another variant is to reweight the objective with some given signal Y. And we wich to
    solve (Weighted)
        minimize_{O, a} \sum_z 1/2|a_z^T O Y- Az^T Y|_2^2+\lambda |a_z|_1,
        s.t. O_{i*}O_{i*}^T \le 1, \forall i.
    This class solves the above two problems.

    Variables:
    -----------
        O: the dictionary
        a: the coefficients
        log: training log for various statistics

    Methods:
    -----------

    """
    Q_INIT_DIAG_VAL = 1.0e-6 # small value to initialize Q diagnal
    O_NONACTIVE_THRESHOLD = 1.0e-3 # threshold to detect nonactive row of O
    NORMAL_VAR = 0.01

    def __init__(self, # pylint: disable=too-many-arguments, too-many-statements
                 A,
                 dict_size,
                 F=None,
                 weighted=True,
                 epoch=5,
                 batch_size=64,
                 eval_step=20,
                 eval_batch_size=-1,
                 a_method='lasso_lars',
                 lam=1.0e-6,
                 n_a_nonzero=20,
                 shuffle=True,
                 num_worker=1,
                 O_Q_ST_accurate=True,
                 O_loop_cnt=1,
                 O_init_method='random_select',
                 O_resample_method='no',
                 O_resample_warmup=5,
                 O_resample_step=20,
                 device="cpu",
                 verbose=True,
                 **kwargs):
        """
        Parameters:
        -----------
            A: (normalized) adjacency matrix representing the graph; every row is one signal
               (data_size by feature_dim).
            dict_size: number of bases (rows of O) in the dictionary; it must not exceed
                       feature_dim (checked by ``_check_param``).
            F: reweighting matrix for the weighted version (required when ``weighted``).
            weighted: switch between the Original solver (False) and the Weighted solver (True).
            epoch: maximum number of training epochs.
            batch_size: batch size in each training iteration.
            eval_step: number of iterations between two evaluation steps.
            eval_batch_size: evaluation batch size, -1 for full batch.
            a_method: solver for the coefficients a in phase a (``_compute_a`` accepts
                      'lasso_lars' and 'max_a_nonzero').
            lam: lasso regularizer for phase a with 'lasso_lars'; it is divided by feature_dim
                 here.
            n_a_nonzero: maximum number of LARS steps for phase a with 'max_a_nonzero'.
            shuffle: whether to shuffle the sampling order in each epoch.
            num_worker: number of worker processes for phase a.
            O_Q_ST_accurate: use the SGD cache or the accurate Q, ST for optimizing O.
            O_loop_cnt: number of iterations in the O phase.
            O_init_method: initialization method for O (``_init_O_a`` accepts 'random_select'
                           and 'random_svd').
            O_resample_method: approach to resample unused rows of O. 'no' for no resampling,
                               'uniform' for uniform resampling, 'greedy' for greedily
                               resampling the rows whose a has the largest loss.
            O_resample_warmup: warm-up iterations before resampling O.
            O_resample_step: number of iterations between two resamplings of O.
            device: torch device used by the torch-based computations.
            verbose: whether to display intermediate results.
            **kwargs: extra options; ``AF`` may provide a precomputed ``A.dot(F)``.
        """
        # problem data and mode
        self.A, self.dict_size, self.F, self.weighted = A, dict_size, F, weighted
        # optimization schedule
        self.epoch, self.batch_size = epoch, batch_size
        self.eval_step, self.eval_batch_size = eval_step, eval_batch_size
        # phase a (coefficients)
        self.a_method, self.lam, self.n_a_nonzero = a_method, lam, n_a_nonzero
        self.shuffle, self.num_worker = shuffle, num_worker
        # phase O (dictionary)
        self.O_Q_ST_accurate, self.O_loop_cnt = O_Q_ST_accurate, O_loop_cnt
        self.O_resample_method = O_resample_method
        self.O_resample_warmup, self.O_resample_step = O_resample_warmup, O_resample_step
        self.O_init_method = O_init_method
        # runtime options
        self.verbose, self.device, self.kwargs = verbose, device, kwargs

        if not weighted:
            if self.verbose:
                print("Original version.")
            self.phase_a = self.phase_a_original
        else:
            if self.verbose:
                print("Weighted version.")
            if self.F is None:
                raise ValueError("F should not be none under weighted mode! ")
            # A.dot(F) is only evaluated when no precomputed AF was supplied
            self.AF = self.kwargs["AF"] if "AF" in self.kwargs else A.dot(F)
            self.phase_a = self.phase_a_weighted

        self.data_size, self.feature_dim = self.A.shape # feature dimension and the data size
        self.global_iter_cnt = None
        self.log = defaultdict(list)
        self._iter_cnt = self._sample_idx = None
        self._max_iter = int(np.ceil(self.data_size / self.batch_size))

        self.lam /= self.feature_dim # renormalization

        # dictionary, coefficients and coefficients cache; then the Q / ST statistics
        self.O = self.a = self.a_old = None
        self.Q = self.ST = None
        self.per_sample_regret = None # per sample regret to pick resampled O

        self._check_param()

    def _check_param(self):
        if self.dict_size > self.feature_dim:
            raise ValueError("Does not support the case "
                             "that dict size is larger than feature dimension! "
                             f"Current dict size {self.dict_size} and "
                             f"current feature dimension {self.feature_dim}")

    def _sampler(self):
        """handle the random mini-batch indices"""
        # initialize the parameters
        self._sample_idx = np.arange(self.data_size)
        if self.shuffle:
            np.random.shuffle(self._sample_idx)
        # main loop
        for self._iter_cnt in range(self._max_iter):
            yield self._sample_idx[self._iter_cnt*self.batch_size:
                                   min((self._iter_cnt+1)*self.batch_size, self.data_size)]

    @staticmethod
    def eval_metrics_torch(O, a, A, F, eval_batch_size=-1, AF=None, device="cpu", verbose=True): # pylint: disable=too-many-locals
        """
        Evaluating the performance metrics
            All written together to boost the efficiency and reduce redundant computation
        """
        if verbose:
            print("---- Evaluating...")

        if F is not None and AF is None:
            raise ValueError("AF should be provided if F is not none. ")

        _, feature_dim = A.shape
        regret, rel_regret, weighted_regret, rel_weighted_regret = np.NAN, np.NAN, np.NAN, np.NAN

        list_DRSS, list_SNSS, list_WDRSS, list_WSNSS = [], [], [], []
        if eval_batch_size == -1:
            eval_batch_size = feature_dim
        max_batch = int(np.ceil(feature_dim/eval_batch_size))

        with torch.no_grad():
            tch_O = torch.from_numpy(O).float().to(device)
            tch_F = torch.from_numpy(F).float().to(device)

            t_start = time()
            for bb in range(max_batch):
                cur_a = torch.from_numpy(a[bb*eval_batch_size:min((bb+1)*eval_batch_size,
                                           feature_dim), :]).float().to(device)
                cur_A = torch.from_numpy(A[bb*eval_batch_size:min((bb+1)*eval_batch_size,
                                         feature_dim)].toarray()).float().to(device)
                a_dot_O = torch.mm(cur_a, tch_O)
                # regret
                diff_row_square_sum = torch.norm(a_dot_O - cur_A, dim=1) ** 2
                self_norm_square_sum = torch.norm(cur_A, dim=1) ** 2
                list_DRSS.append(diff_row_square_sum.cpu().detach().numpy())
                list_SNSS.append(self_norm_square_sum.cpu().detach().numpy())
                # weighted regret
                if F is not None:
                    cur_AF =\
                        torch.from_numpy(
                            AF[bb*eval_batch_size:min((bb+1)*eval_batch_size,
                                    feature_dim), :]).float().to(device)
                    a_dot_O_dot_F = torch.mm(a_dot_O, tch_F)
                    weighted_diff_row_square_sum = torch.norm(a_dot_O_dot_F - cur_AF, dim=1) ** 2
                    weighted_self_norm_square_sum = torch.norm(cur_AF, dim=1) ** 2
                    list_WDRSS.append(weighted_diff_row_square_sum.cpu().detach().numpy())
                    list_WSNSS.append(weighted_self_norm_square_sum.cpu().detach().numpy())
                if verbose:
                    cur_time = time() - t_start
                    ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                    sys.stdout.write(f"{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                                     f"ETA {ETA:.1f}\r")
                    sys.stdout.flush()
            if verbose:
                print()
            DRSS = np.concatenate(list_DRSS)
            SNSS = np.concatenate(list_SNSS)
            regret = np.sum(DRSS)
            per_sample_rel_regret = DRSS / SNSS
            rel_regret = np.mean(per_sample_rel_regret)
            # weighted regret
            if F is not None:
                WDRSS = np.concatenate(list_WDRSS)
                WSNSS = np.concatenate(list_WSNSS)
                weighted_regret = np.sum(WDRSS)
                per_sample_rel_weighted_regret = WDRSS / WSNSS
                rel_weighted_regret = np.mean(per_sample_rel_weighted_regret)
                cached_per_sample_regret = per_sample_rel_weighted_regret
            else:
                cached_per_sample_regret = per_sample_rel_regret
        return regret, rel_regret, weighted_regret, rel_weighted_regret, cached_per_sample_regret

    @staticmethod
    def _init_O_a(dict_size, target, method="random_svd"):
        data_size = target.shape[0]
        if "random_select" == method:
            seed = np.random.choice(list(range(data_size)), dict_size, replace=False)
            O = target[seed]
            if not isinstance(target, np.ndarray):
                O = O.toarray()
            O += np.random.normal(0, SparseDictionaryLearning.NORMAL_VAR, O.shape)
            a = np.zeros([data_size, dict_size])
        elif "random_svd" == method:
            a, S, dictionary = randomized_svd(target, dict_size, random_state=0)
            O = S[:, np.newaxis] * dictionary
        else:
            raise ValueError("O initialization method could be random_select, random_svd."
                             f"Undefined method for initilize O: {method}.")
        return O, a

    @staticmethod
    def _init_QS(dict_size, feature_dim):
        Q = SparseDictionaryLearning.Q_INIT_DIAG_VAL * np.identity(dict_size)
        S = np.zeros([dict_size, feature_dim])
        return Q, S

    @staticmethod
    def _compute_a(v, O, gram, Xy, method='lasso_lars', lam=1.0e-3, n_a_nonzero=20):
        """
        Keep this structure for multi-processing extension
        """
        if method == "lasso_lars":
            reg = linear_model.LassoLars(alpha=lam, precompute=gram,
                                         normalize=False, fit_intercept=False)
            reg.fit(X=O, y=v, Xy=Xy)
            res = reg.coef_
        elif method == "max_a_nonzero":
            if isinstance(n_a_nonzero, list):
                max_iter = n_a_nonzero[-1]
            else:
                max_iter = n_a_nonzero
            reg = linear_model.LassoLars(alpha=1e-7, precompute=gram, max_iter=max_iter,
                                         normalize=False, fit_intercept=False)
            reg.fit(X=O, y=v, Xy=Xy)
            res = reg.coef_
        else:
            raise ValueError("a phase method could only be lasso_lars, lars."
                             f"Unrecognized a method {method}.")
        return res

    @staticmethod
    def _thread_wrapper(p_str_res, ii, v, O, gram, Xy, method, lam, n_a_nonzero): # pylint: disable=too-many-arguments
        res = SparseDictionaryLearning._compute_a(v,
                                                  O,
                                                  gram,
                                                  Xy,
                                                  method=method,
                                                  lam=lam,
                                                  n_a_nonzero=n_a_nonzero)
        p_str_res.value = pickle.dumps([ii, res], protocol=0)


    def compute_a(self, v, O, method='lasso_lars', lam=1.0e-3, n_a_nonzero=20,
                  num_worker=1):
        """
        Compute the sparse codes of the rows of ``v`` with respect to ``O``.

        Each row of ``v`` is one regression target, and the rows of ``O`` (the
        dictionary atoms) are the predictors, i.e. ``v[i] ~ a[i] @ O``. The
        Gram matrix ``O @ O.T`` and ``O @ v.T`` are computed once here and
        shared by every LassoLars solve.

        Args:
            v: dense 2-D array, one sample per row.
            O: dictionary, one atom per row.
            method, lam, n_a_nonzero: forwarded to ``_compute_a``.
            num_worker: number of worker processes. 1 solves in-process;
                otherwise the samples are split into contiguous blocks.

        Returns:
            The codes, one row per sample of ``v`` (LassoLars ``coef_``).

        Raises:
            RuntimeError: if a worker process exits with a non-zero exit code.
        """
        num_a, _ = v.shape
        # NOTE: the data is not normalized here. If normalization is added,
        # the solution has to be scaled back before it is returned.
        X = O.transpose()
        y = v.transpose()
        gram = torch_batch_matrix_mul_matrix_list(X.transpose(), X, batch_size=self.eval_batch_size,
                                                  device=self.device)
        Xy = np.dot(X.transpose(), y)

        if num_worker == 1:
            res = SparseDictionaryLearning._compute_a(y,
                                                      X,
                                                      gram,
                                                      Xy,
                                                      method=method,
                                                      lam=lam,
                                                      n_a_nonzero=n_a_nonzero)
            return np.array(res)

        if num_a == 0:
            return np.zeros((0, O.shape[0]))
        block_size = int(np.ceil(num_a / num_worker))
        # Contiguous column ranges of y. Ranges that would be empty (when
        # num_worker does not divide num_a evenly) are not created.
        bounds = [(start, min(num_a, start + block_size))
                  for start in range(0, num_a, block_size)]
        with Manager() as manager:
            res_cache = [manager.Value(c_char_p, "") for _ in bounds]
            my_procs = [
                Process(target=SparseDictionaryLearning._thread_wrapper,
                        args=(res_cache[ii],
                              ii,
                              y[:, start:end],
                              X,
                              gram,
                              # LassoLars pairs column k of Xy with column k of y,
                              # so each block needs its own columns of Xy.
                              Xy[:, start:end],
                              method,
                              lam,
                              n_a_nonzero))
                for ii, (start, end) in enumerate(bounds)]
            for proc in my_procs:
                proc.start()
            for proc in my_procs:
                proc.join()
            failed = [ii for ii, proc in enumerate(my_procs) if proc.exitcode != 0]
            if failed:
                raise RuntimeError(f"compute_a worker(s) {failed} failed; "
                                   "see the worker traceback on stderr.")
            # Read the results before the manager (and its shared values) shuts down.
            results = [pickle.loads(cache.value) for cache in res_cache]
        results.sort(key=itemgetter(0))
        # A single-target block may come back as a 1-D coef_, so make every block 2-D.
        return np.concatenate([np.atleast_2d(block) for _, block in results], axis=0)

    @staticmethod
    def update_O_torch(O, Q, ST, O_loop_cnt=1, device="cpu"):
        dict_size, _ = O.shape

        with torch.no_grad():
            torch_O = torch.from_numpy(O).float().to(device)
            torch_Q = torch.from_numpy(Q).float().to(device)
            torch_ST = torch.from_numpy(ST).float().to(device)
            for _ in range(O_loop_cnt):
                for j in range(dict_size):
                    tmp = (torch_ST[j, :] - torch_Q[j, :] @ torch_O) / torch_Q[j, j] + torch_O[j, :]
                    torch_O[j, :] = tmp / torch.clamp(torch.norm(tmp), min=1)
        O[:] = torch_O.cpu().detach().numpy()

    def eval_and_log(self):
        # this_regret, this_rel_regret, this_weighted_regret, this_rel_weighted_regret,\
        #     self.per_sample_regret = self.eval_metrics(self.O, self.a, self.A, self.F,
        #                                                eval_batch_size=self.eval_batch_size)
        this_regret, this_rel_regret, this_weighted_regret, this_rel_weighted_regret,\
            self.per_sample_regret = self.eval_metrics_torch(self.O, self.a, self.A, self.F,
                                                             eval_batch_size=self.eval_batch_size,
                                                             AF = self.AF, device=self.device,
                                                             verbose=self.verbose)
        self.log["regret"].append(this_regret)
        self.log["rel_regret"].append(this_rel_regret)
        self.log["weighted_regret"].append(this_weighted_regret)
        self.log["rel_weighted_regret"].append(this_rel_weighted_regret)

        self.log["a_row_nonzero"].append(np.sum(self.a!=0, axis=1))
        self.log["O_row_nonzero"].append(np.sum(self.O!=0, axis=1))

    def _display_stat(self):
        print(f"Weighted regret: {self.log['weighted_regret'][-1]:.6f},"
              f" | Rel weighted regret: {self.log['rel_weighted_regret'][-1]:.6f}"
              f" | Regret: {self.log['regret'][-1]:.6f}"
              f" | Rel Regret: {self.log['rel_regret'][-1]:.6f}")
        print(f"  a row nonzeros: {np.mean(self.log['a_row_nonzero'][-1]):.4f} " + u"\u00B1"
              f" {np.std(self.log['a_row_nonzero'][-1]):.4f}"
              f" | O row nonzeros: mean {np.mean(self.log['O_row_nonzero'][-1]):.4f} " + u"\u00B1"
              f" {np.std(self.log['O_row_nonzero'][-1]):.4f}")

    def fit(self): # pylint: disable=too-many-branches, too-many-statements
        """
        Pure fit without O sparsification.

        Initializes O and a (``_init_O_a``) and the O-phase state Q, ST
        (``_init_QS``). Then, for ``self.epoch`` passes over the mini-batches
        produced by ``self._sampler()``, it alternates between the a phase
        (``self.phase_a``) and the O phase (``update_state_for_O`` followed by
        ``update_O_torch``). Unused dictionary rows are resampled every
        ``O_resample_step`` iterations, and metrics are evaluated and logged
        every ``eval_step`` iterations. Ends with ``post_update()``, called
        with its defaults so O is not sparsified, followed by a final
        evaluation.

        Returns:
            self
        """
        tic_all_start = time()
        # initialization
        if self.verbose:
            print("Initializing...")
        self.O, self.a = self._init_O_a(self.dict_size, self.A, method=self.O_init_method)
        if DEBUG:
            # keep the initial dictionary (resample_O overwrites its resampled rows)
            self.log["O_init"].append(self.O.copy())
        if self.O_Q_ST_accurate:
            # previous codes; the exact Q/ST update swaps their contribution out
            self.a_old = self.a.copy()
        self.Q, self.ST = self._init_QS(self.dict_size, self.feature_dim)

        self.global_iter_cnt = 0
        # total number of mini-batch iterations; only used for the ETA estimate
        total_iter = self.epoch * self._max_iter
        # main loop
        self.eval_and_log()
        if self.verbose:
            print("Training started...")
            self._display_stat()
            print("-"*27)
            print()
        tic_start = time()
        for e in range(self.epoch):
            for it, cur_idx in enumerate(self._sampler()):
                tic_iter_start = time()
                # phase a
                tic_iter_a_start = time()
                cur_a, cur_v = self.phase_a(cur_idx)
                self.log["time_compute_a"].append(time() - tic_iter_a_start)
                # phase O
                tic_iter_O_start = time()
                # global_iter_cnt is still 0 here on the very first iteration
                self.update_state_for_O(cur_a, cur_v, cur_idx)
                if DEBUG:
                    # NOTE: measures update_state_for_O only, not update_O_torch
                    time_O_update = time() - tic_iter_O_start
                # self.update_O(self.O, self.Q, self.ST, O_loop_cnt=self.O_loop_cnt)
                self.update_O_torch(self.O, self.Q, self.ST,
                                    O_loop_cnt=self.O_loop_cnt, device=self.device)
                self.log["time_compute_O"].append(time() - tic_iter_O_start)

                self.log["time_iter"].append(time() - tic_iter_start)

                if DEBUG:
                    print(f"*** Time a {self.log['time_compute_a'][-1]:.1f} | "
                          f"Time O {self.log['time_compute_O'][-1]:.1f} | "
                          f"Time O update {time_O_update:.1f}")

                # from here on: number of completed iterations over all epochs
                self.global_iter_cnt += 1
                if self.global_iter_cnt % self.O_resample_step == 0:
                    self.resample_O(self.O_resample_method)
                # evaluation and display
                if self.global_iter_cnt % self.eval_step == 0:
                    self.eval_and_log()
                    if self.verbose:
                        elapsed_time = time() - tic_start
                        # linear extrapolation of the average time per iteration
                        ETA = elapsed_time /\
                            self.global_iter_cnt * (total_iter - self.global_iter_cnt)
                        print(f"Epoch: {e} | Iter: {it} | "
                              f"Elapsed time: {elapsed_time:.1f} | ETA: {ETA:.1f}")
                        self._display_stat()

        # Post update for the parameters
        if self.verbose:
            print("-"*27)
            print()
            print("Started post processing...")
        # default arguments: recompute a only, without sparsifying O
        self.post_update()
        self.eval_and_log()
        if self.verbose:
            print(f"Training finished in {(time()-tic_all_start):.1f} s.")
            self._display_stat()
        return self

    def post_update(self, O_is_sparsify=False, O_sparsify_conf=None):
        """
        Post processing for the weights
        """
        if O_is_sparsify:
            # use percentile thresholing instead of the full version
            if self.verbose:
                print("Using O sparsify mode. Sparsifing O...")
            self.O, signal_list = self.sparsify_O(O_sparsify_conf)
            self.log["signal_list"] = signal_list
        if self.verbose:
            print("Recompute a...")
        tic_start = time()
        for it, cur_idx in enumerate(self._sampler()):
            self.phase_a(cur_idx)
            if self.verbose:
                cur_time = time() - tic_start
                ETA = cur_time / (it + 1) * (self._max_iter - it - 1)
                print(f'{it/self._max_iter:.3f} finished in {cur_time:.1f} s. ETA: {ETA:.1f} s',
                      end='\r', flush=True)
        if self.verbose:
            print()

    def sparsify_O(self, O_sparsify_conf):
        """
        sparsify O according to some approach
        """
        def get_remove_idx(O_row, O_sparsify_conf):
            cur_val = np.abs(O_row)
            cur_val = cur_val / np.sum(cur_val)
            sort_idx = np.argsort(-cur_val)
            cur_val = cur_val[sort_idx]
            i, nnz, signal =0, 0, 0
            while nnz < O_sparsify_conf["max_nonzero"] and\
                signal < O_sparsify_conf["signal_ratio"]:
                nnz += 1
                signal += cur_val[i]
                i += 1
            return sort_idx[i:], signal

        new_O = self.O.copy()
        c, _ = new_O.shape
        signal_list = []
        for i in range(c):
            to_remove, signal = get_remove_idx(new_O[i,:], O_sparsify_conf)
            new_O[i, to_remove] = 0
            signal_list.append(signal)
        return new_O, signal_list

    def phase_a_original(self, cur_idx):
        cur_v = self.A[cur_idx, :].toarray()
        cur_a = self.compute_a(cur_v, self.O, method=self.a_method,
                               lam=self.lam, n_a_nonzero=self.n_a_nonzero,
                               num_worker=self.num_worker)
        self.a[cur_idx, :] = cur_a
        return cur_a, cur_v

    def phase_a_weighted(self, cur_idx):
        """
        Sparse-code the samples ``cur_idx`` in the F-weighted space.

        The rows ``self.AF[cur_idx, :]`` are coded against the projected
        dictionary ``O @ F``, and the codes are stored in
        ``self.a[cur_idx, :]``. Returns ``(codes, dense rows of self.A)``; the
        unprojected rows are what the caller hands to the O phase.
        """
        weighted_rows = self.AF[cur_idx, :]
        dense_rows = self.A[cur_idx, :].toarray()
        projected_O = torch_batch_matrix_mul_matrix_list(self.O, self.F,
                                                         batch_size=self.eval_batch_size,
                                                         device=self.device)
        codes = self.compute_a(weighted_rows, projected_O,
                               method=self.a_method,
                               lam=self.lam,
                               n_a_nonzero=self.n_a_nonzero,
                               num_worker=self.num_worker)
        self.a[cur_idx, :] = codes
        return codes, dense_rows

    def _update_state_for_O_accurate(self, cur_a, cur_v, cur_idx):
        cur_a_old = self.a_old[cur_idx, :]
        self.Q += cur_a.T @ cur_a - cur_a_old.T @ cur_a_old
        self.ST += (cur_a - cur_a_old).T @ cur_v
        self.a_old[cur_idx, :] = cur_a

    def _update_state_for_O_approximate(self, cur_a, cur_v):
        if self.global_iter_cnt < self.batch_size:
            theta = self.global_iter_cnt * self.batch_size
        else:
            theta = self.batch_size ** 2 + self.global_iter_cnt - self.batch_size
        beta = (theta + 1 - self.batch_size) / (theta + 1)

        self.Q = self.Q * beta + cur_a.T @ cur_a
        self.ST = self.ST * beta + cur_a.T @ cur_v

    def update_state_for_O(self, cur_a, cur_v, cur_idx):
        """
        aggregate import informtion for O phase
            The detailed algorithm is from Oneline Dictionary Learning for Sparse Coding
        """
        if self.O_Q_ST_accurate:
            self._update_state_for_O_accurate(cur_a, cur_v, cur_idx)
        else:
            self._update_state_for_O_approximate(cur_a, cur_v)

    def resample_O(self, method):
        """
        Re-initialize the rows of O that the current codes ``a`` barely use.

        Row j of O is replaced when ``sum_i |a[i, j]| <=
        O_NONACTIVE_THRESHOLD``. Each replaced row is set to a row of ``A``
        plus Gaussian noise, and the code of that sample is reset so it selects
        only the new row. While ``global_iter_cnt <= O_resample_warmup`` nothing
        is resampled and 0 is logged.

        Args:
            method: "no" returns immediately without logging anything.
                "uniform" draws the samples uniformly without replacement.
                "greedy" takes the samples with the largest
                ``self.per_sample_regret`` (set by ``eval_and_log``).

        Raises:
            ValueError: for any other method.
        """
        if method == "no": # pylint: disable=no-else-return
            return
        elif method not in ["uniform", "greedy"]:
            raise ValueError("O resampling method can only be no, uniform, greedy."
                             f"Invalid O resampling method {method}.")
        if self.global_iter_cnt <= self.O_resample_warmup:
            self.log["num_resampled_O"].append(0)
            return
        if self.verbose:
            print("Resampling O...")
        # get the index to resample
        # utility of atom j = total absolute weight it receives over all codes
        utility = np.sum(np.abs(self.a), axis=0)
        O_rows_to_replace = np.where(utility <= SparseDictionaryLearning.O_NONACTIVE_THRESHOLD)[0]
        # get the a rows to feed. Idea from
        # https://stackoverflow.com/questions/58070203/
        # find-top-k-largest-item-of-a-list-in-original-order-in-python
        if method == "greedy":
            # indices of the samples with the largest regret, at most one per replaced row
            A_rows_to_feed = heapq.nlargest(len(O_rows_to_replace),
                                            enumerate(self.per_sample_regret),
                                            key=itemgetter(1))
            A_rows_to_feed = [ii for (ii, val) in A_rows_to_feed]
        elif method == "uniform":
            # NOTE(review): np.random.choice raises ValueError if more rows must be
            # replaced than data_size (sampling without replacement).
            A_rows_to_feed = np.random.choice(list(range(self.data_size)),
                                              size=len(O_rows_to_replace), replace=False)
        # execute the resample
        # NOTE(review): zip stops at the shorter list, while num_resampled_O below
        # always logs len(O_rows_to_replace).
        for o_row, a_row in zip(O_rows_to_replace, A_rows_to_feed):
            # NORMAL_VAR is passed as numpy's `scale`, i.e. the standard deviation.
            # The new row is not rescaled to norm <= 1 as update_O_torch does.
            self.O[o_row, :] = self.A[a_row, :].toarray() +\
                np.random.normal(0, SparseDictionaryLearning.NORMAL_VAR, (1, self.O.shape[1]))
            # roughly fill the a to prevent degenerating cases
            self.a[a_row, :] = 0
            self.a[a_row, o_row] = 1.
            if DEBUG:
                # record the resampled row in the last stored initial dictionary
                self.log["O_init"][-1][o_row, :] = self.O[o_row, :]
        # update inner state
        if self.O_Q_ST_accurate:
            # swap the reset codes of the fed samples into Q and ST
            self._update_state_for_O_accurate(self.a[A_rows_to_feed, :],
                                              self.A[A_rows_to_feed, :],
                                              A_rows_to_feed)

        self.log["num_resampled_O"].append(len(O_rows_to_replace))

    def save(self, res_folder):
        with open(os.path.join(res_folder, "dict.pkl"), 'wb') as fout:
            pickle.dump([self.O, self.a], fout)
        with open(os.path.join(res_folder, "log.pkl"), "wb") as fout:
            pickle.dump(self.log, fout)
        if self.verbose:
            print(f"Results are saved in {res_folder}.")

    def load(self, res_folder):
        if self.verbose:
            print(f"Loading results from {res_folder}.")
        with open(os.path.join(res_folder, "dict.pkl"), 'rb') as fin:
            [self.O, self.a] = pickle.load(fin)
        with open(os.path.join(res_folder, "log.pkl"), "rb") as fin:
            self.log = pickle.load(fin)
        cur_O_shape = self.O.shape
        if self.dict_size != cur_O_shape[0]:
            raise AttributeError("Dict size from loaded results does not match the configuration"
                                 f"{self.dict_size} != {cur_O_shape}")

    ###########
    ## Unused methods
    @staticmethod
    def update_O(O, Q, ST, O_loop_cnt=1):
        dict_size, _ = O.shape
        for _ in range(O_loop_cnt):
            for j in range(dict_size):
                tmp = (ST[j, :] - np.dot(Q[j, :], O)) / Q[j, j] + O[j, :]
                O[j, :] = tmp / max([1, np.linalg.norm(tmp)])

    @staticmethod
    def weighted_regret(O, a, A, F):
        return np.linalg.norm(np.dot(a, O).dot(F)-A.dot(F)) ** 2

    @staticmethod
    def rel_weighted_regret(O, a, A, F):
        res = []
        total, _ = a.shape
        appro = np.dot(a, O).dot(F)
        src = A.dot(F)
        for idx in range(total):
            self_norm = np.linalg.norm(src[idx, :]) ** 2
            diff_norm = np.linalg.norm(src[idx, :] - appro[idx, :]) ** 2
            res.append(diff_norm/self_norm)
        return np.mean(res)

    @staticmethod
    def regret(O, a, A):
        return np.linalg.norm(np.dot(a, O)-A) ** 2

    @staticmethod
    def rel_regret(O, a, A):
        res = []
        total, _ = a.shape

        appro = np.dot(a, O)
        for idx in range(total):
            self_norm = scipy.sparse.linalg.norm(A[idx, :]) ** 2
            diff_norm = np.linalg.norm(A[idx, :] - appro[idx, :]) ** 2
            res.append(diff_norm/self_norm)
        return np.mean(res)


class SMFRProxLinX(SparseDictionaryLearning): # pylint: disable=too-many-instance-attributes
    """
    Sparse matrix factorization regression (SMFR) with proximal linear updates.

    For a fixed number of factors m it minimizes

        0.5 * ||Y - X A B||_F^2 + lam1 * ||A||_1 + lam2 * ||B||_1 + lam3 * ||A||_F^2

    by alternating extrapolated (FISTA-style) proximal-gradient steps on B and
    on A. A step that does not decrease the objective is redone without
    extrapolation. X'X, X'Y and ||X'X||_F are computed once per ``fit`` and
    reused by every gradient evaluation.

    Inputs:
    - X: Matrix of predictors (n x p)
    - Y: Matrix of responses (n x q)
    - lam1: Regularization factor for ||A||_1
    - lam2: Regularization factor for ||B||_1
    - lam3: Regularization factor for ||A||^2_F
    - num_factor_init: Initial number of factors. The algorithm starts from
      this number and decreases it until A and B both have full rank
    - adj: related target graph matrix (kept for graph-related evaluation;
      not used by ``fit``)
    - device: stored on the instance; not used by ``fit``
    - display_step: print progress every ``display_step`` iterations
    - verbose: whether to print progress
    Outputs of ``fit`` (also stored as self.A, self.B, self.num_factors):
    - A: Matrix deriving factors from inputs (p x m)
    - B: Matrix of regression coefficients from factors to outputs (m x q)
    - num_factors: Estimated number of factors
    """
    TOL = 1e-6
    MAX_ITER = 10000
    def __init__(self, # pylint: disable=too-many-arguments, super-init-not-called
                 X,
                 Y,
                 lam1,
                 lam2,
                 lam3,
                 num_factor_init,
                 adj=None,
                 device="cpu",
                 display_step=20,
                 verbose=True):
        self.X = X
        self.Y = Y
        self.lam1 = lam1
        self.lam2 = lam2
        self.lam3 = lam3
        self.num_factor_init = num_factor_init
        self.adj = adj
        self.device = device
        self.display_step = display_step
        self.verbose = verbose

        self.num_preds = X.shape[1]
        self.num_resps = Y.shape[1]

        self.A, self.B, self.num_factors = None, None, None
        self.log = defaultdict(list)

    def get_obj(self, X, Y, A, B):
        """
        Objective 0.5||Y - XAB||_F^2 + lam1||A||_1 + lam2||B||_1 + lam3||A||_F^2.

        The lam3 term is squared, consistent with its gradient 2 * lam3 * A
        used in the A update.
        """
        return (0.5 * np.linalg.norm(Y - X @ A @ B) ** 2
                + self.lam1 * np.sum(np.abs(A))
                + self.lam2 * np.sum(np.abs(B))
                + self.lam3 * np.linalg.norm(A) ** 2)

    @staticmethod
    def _prox_step(M, G, L, lam):
        """
        Proximal-gradient step: threshold ``M - G / L`` with ``np_wthresh`` at ``lam / L``.

        ``L`` is the Lipschitz bound of the smooth part. When it is 0 the
        smooth part does not depend on M (its gradient G is zero too), so the
        zero matrix minimizes the remaining l1 term. It is returned instead of
        dividing by zero.
        """
        if L == 0:
            return np.zeros_like(M)
        return np_wthresh(M - G / L, lam / L)

    def fit(self):
        """
        Fit A and B, decreasing the number of factors until both have full rank.

        Returns:
            (A, B, num_factors) for the first factor count, counting down from
            ``num_factor_init``, at which ``matrix_rank(A) == matrix_rank(B)
            == num_factors``. If none qualifies, the result for one factor is
            returned.

        Raises:
            ValueError: if ``num_factor_init`` is smaller than 1.
        """
        if self.num_factor_init < 1:
            raise ValueError("num_factor_init must be at least 1, "
                             f"got {self.num_factor_init}.")
        tic_main = time()
        # These only depend on the data, so compute them once for all factor counts.
        XTX = self.X.transpose() @ self.X
        XTY = self.X.transpose() @ self.Y
        norm_XTX = np.linalg.norm(XTX)
        for num_factors in range(self.num_factor_init, 0, -1):
            if self.verbose:
                print(f"Factor: {num_factors} started...")
            A, B = self._fit_fixed_factors(num_factors, XTX, XTY, norm_XTX, tic_main)
            self.A, self.B, self.num_factors = A, B, num_factors
            if (np.linalg.matrix_rank(B) == num_factors and
                    np.linalg.matrix_rank(A) == num_factors):
                break
        return self.A, self.B, self.num_factors

    def _fit_fixed_factors(self, num_factors, XTX, XTY, norm_XTX, tic_main): # pylint: disable=too-many-arguments, too-many-locals
        """Alternate accelerated proximal-gradient updates of B and A for a fixed factor count."""
        A = np.random.rand(self.num_preds, num_factors)
        B = np.random.rand(num_factors, self.num_resps)
        Aprev, Bprev = A, B
        # `obj` always holds the objective of the current (A, B).
        obj = self.get_obj(self.X, self.Y, A, B)
        t0 = 1
        for i in range(self.MAX_ITER):
            tic_iter = time()
            obj_start = obj
            # extrapolation weight, shared by the B and the A update
            t = (1 + np.sqrt(1 + 4 * t0 ** 2)) / 2
            w = (t0 - 1) / t
            t0 = t

            # --- updating B --- Lipschitz bound of grad_B: ||A' X'X A||
            LB = np.linalg.norm(A.transpose() @ XTX @ A)
            Bhat = B + w * (B - Bprev)
            Ghat = A.transpose() @ (XTX @ A @ Bhat - XTY)
            Bnew = self._prox_step(Bhat, Ghat, LB, self.lam2)
            obj_new = self.get_obj(self.X, self.Y, A, Bnew)
            Bprev = B
            if obj_new >= obj:
                # no descent: repeat the step without extrapolation
                G = A.transpose() @ (XTX @ A @ B - XTY)
                B = self._prox_step(B, G, LB, self.lam2)
                obj = self.get_obj(self.X, self.Y, A, B)
            else:
                B, obj = Bnew, obj_new

            # --- updating A --- Lipschitz bound of grad_A: ||X'X|| ||B B'|| + 2 lam3
            LA = norm_XTX * np.linalg.norm(B @ B.transpose()) + 2 * self.lam3
            Ahat = A + w * (A - Aprev)
            Ghat = (XTX @ Ahat @ B - XTY) @ B.transpose() + 2 * self.lam3 * Ahat
            Anew = self._prox_step(Ahat, Ghat, LA, self.lam1)
            obj_new = self.get_obj(self.X, self.Y, Anew, B)
            Aprev = A
            if obj_new >= obj:
                # no descent: repeat the step without extrapolation
                G = (XTX @ A @ B - XTY) @ B.transpose() + 2 * self.lam3 * A
                A = self._prox_step(A, G, LA, self.lam1)
                obj = self.get_obj(self.X, self.Y, A, B)
            else:
                A, obj = Anew, obj_new

            if self.verbose and (i+1) % self.display_step == 0:
                print(f"Factor: {num_factors}, iter: {i}, obj:{obj:.4f} "
                      f"iter time: {time()-tic_iter:.1f} s, "
                      f"elapsed time: {time()-tic_main:.1f} s. norm A "
                      f"{np.linalg.norm(A):.1f} norm B {np.linalg.norm(B):.1f}")

            # --- stopping check: relative objective change over the full iteration ---
            if np.abs(obj - obj_start) / (obj_start + 1e-8) < self.TOL:
                break
        return A, B
